In [1]:
import numpy as np
import pickle
import os
import json
import logging
import torch
import shutil

In [2]:
def save_txt(filename, data):
    with open(filename, 'a') as f:
        f.write(data)
    f.close()

In [4]:
# 示例：将字符串 "Hello, World!" 追加到文件 "example.txt" 中
save_txt("example.txt", "Hello, World!\n")

In [3]:
import os
import tempfile

def test_save_txt():
    # 创建一个临时文件
    temp_file = tempfile.NamedTemporaryFile(delete=False)
    temp_file_path = temp_file.name
    temp_file.close()

    # 定义要写入文件的数据
    data_to_write = "This is a test.\n"

    # 调用 save_txt 函数
    save_txt(temp_file_path, data_to_write)

    # 读取文件内容以验证数据
    with open(temp_file_path, 'r') as f:
        file_content = f.read()

    # 断言文件内容与预期相符
    assert file_content == data_to_write, f"Expected '{data_to_write}', but got '{file_content}'"

    # 清理临时文件
    os.remove(temp_file_path)

# 运行测试用例
test_save_txt()

In [7]:
def list_to_csv(li):
    result = ''
    for i in li:
        result += '{}, '.format(i)
    return result.rstrip(', ')

In [8]:
def test_list_to_csv():
    # 定义一个测试列表
    test_list = [1, 2, 3, 4, 5]
    
    # 调用 list_to_csv 函数，将测试列表转换为 CSV 格式的字符串
    csv_string = list_to_csv(test_list)
    
    # 验证结果字符串是否与预期的 CSV 格式字符串相匹配
    expected_csv_string = "1, 2, 3, 4, 5"
    assert csv_string == expected_csv_string, f"Expected '{expected_csv_string}', but got '{csv_string}'"
    
    # 测试空列表的情况
    empty_list = []
    csv_string_empty = list_to_csv(empty_list)
    assert csv_string_empty == "", f"Expected empty string for empty list, but got '{csv_string_empty}'"

# 调用测试函数
test_list_to_csv()

In [9]:
def regeneralize(result, mean, std):  # 输入的result是个array
    mean_matrix = mean
    std_matrix = std
    result_real_scale = (result * std_matrix) + mean_matrix
    return result_real_scale

In [10]:
import numpy as np

def test_regeneralize():
    # 假设我们有一个已经标准化的数据集
    result = np.array([-1.0, 0.0, 1.0]) # 标准化后的数据
    mean = np.array([0.0]) # 原始数据的均值
    std = np.array([1.0]) # 原始数据的标准差

    # 调用regeneralize函数
    regeneralized_result = regeneralize(result, mean, std)

    # 验证结果
    expected_result = np.array([-1.0, 0.0, 1.0]) # 预期的结果
    assert np.allclose(regeneralized_result, expected_result), "Test failed: regeneralized result does not match expected result."
    print("Test passed: regeneralized result matches expected result.")

# 运行测试
test_regeneralize()

Test passed: regeneralized result matches expected result.


In [23]:
def get_file_list(prefix, folder):
    result = []
    file_list = os.listdir(folder)
    for file in file_list:
        if file.endswith(prefix):
            result.append(os.path.join(folder, file))
    return result


In [24]:
import os

def test_get_file_list():
    # 定义测试文件夹和后缀
    folder = './data/' # 请确保这是实际的文件夹路径
    prefix = '.csv'
    
    # 调用函数获取文件列表
    file_list = get_file_list(prefix, folder)
    
    # 预期结果：文件夹中所有以 .csv 结尾的文件
    expected_files = [
        os.path.join(folder, 'vertical_all_A1.csv'),
        os.path.join(folder, 'vertical_all_A2.csv'),
        os.path.join(folder, 'vertical_all_A3.csv'),
        os.path.join(folder, 'vertical_all_A4.csv'),
        os.path.join(folder, 'vertical_all_A5.csv'),
        os.path.join(folder, 'vertical_all_A6.csv')
    ]
    
    # 断言函数返回的文件列表与预期结果相同
    assert file_list == expected_files, f"Expected {expected_files}, but got {file_list}"

# 运行测试
test_get_file_list()

是的，这两个功能在某种程度上是相似的，但它们的实现方式和目的有所不同。

1. **使用 `file.endswith(prefix)` 的功能**：
   - 这个功能的目的是找到所有以特定后缀结尾的文件。
   - 它使用 `endswith` 方法来检查文件名是否以指定的后缀结尾。
   - 这个方法适用于查找文件类型，例如所有的 `.txt` 文件。

2. **使用 `re.match(f'^{prefix}', file)` 的功能**：
   - 这个功能的目的是找到所有以特定前缀开头的文件。
   - 它使用正则表达式来检查文件名是否以指定的前缀开头。
   - 这个方法适用于查找以特定字符串开头的文件，例如所有以 `vertical_all_A` 开头的文件。

虽然这两个功能在某种程度上是相似的，但它们的使用场景和实现方式有所不同。`endswith` 方法是一个简单的字符串方法，用于检查字符串是否以指定的后缀结尾。而 `re.match` 是一个正则表达式函数，提供了更强大的模式匹配功能，可以用于检查字符串是否以指定的前缀开头。

在你的情况下，由于你需要查找以特定前缀开头的文件，使用正则表达式是更合适的选择。

In [20]:
import os
import re

def get_file_list(prefix, folder):
    result = []
    file_list = os.listdir(folder)
    for file in file_list:
        if re.match(f'^{prefix}', file):
            result.append(os.path.join(folder, file))
    return result

In [21]:
import os

def test_get_file_list():
    # 定义测试文件夹和前缀
    folder = './data/' # 请替换为实际的文件夹路径
    prefix = 'vertical_all_A'
    print(os.listdir(folder))
    # 调用函数获取文件列表
    file_list = get_file_list(prefix, folder)
    
    # 预期结果：文件夹中所有以 'vertical_all_A' 开头的 .csv 文件
    expected_files = [
        os.path.join(folder, 'vertical_all_A1.csv'),
        os.path.join(folder, 'vertical_all_A2.csv'),
        os.path.join(folder, 'vertical_all_A3.csv'),
        os.path.join(folder, 'vertical_all_A4.csv'),
        os.path.join(folder, 'vertical_all_A5.csv'),
        os.path.join(folder, 'vertical_all_A6.csv')
    ]
    
    # 断言函数返回的文件列表与预期结果相同
    assert file_list == expected_files, f"Expected {expected_files}, but got {file_list}"

# 运行测试
test_get_file_list()

['vertical_all_A1.csv', 'vertical_all_A2.csv', 'vertical_all_A3.csv', 'vertical_all_A4.csv', 'vertical_all_A5.csv', 'vertical_all_A6.csv']


In [25]:
def save_var(v, filename):
    f = open(filename, 'wb')
    pickle.dump(v, f)
    f.close()
    return filename

In [26]:
def test_save_var():
    # 创建一个简单的字典对象
    test_dict = {'key': 'value', 'number': 42}
    
    # 使用 save_var 函数保存字典到文件
    filename = 'test_dict.pkl'
    save_var(test_dict, filename)
    
    # 读取文件并验证内容
    with open(filename, 'rb') as f:
        loaded_dict = pickle.load(f)
    
    # 验证加载的字典与原始字典相同
    assert loaded_dict == test_dict, "Loaded dictionary does not match the original one."
    
    print("Test passed: The dictionary was saved and loaded correctly.")

# 运行测试
test_save_var()

Test passed: The dictionary was saved and loaded correctly.


In [38]:
def shrink(data, size):
    i = 0
    while i+size < data.shape[1]:
        yield data[:, i:i+size, :].mean(axis=1)
        i += size

In [39]:

# 创建一个假的数据集，形状为(3, 10, 2)
data = np.array([
    [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16], [17, 18], [19, 20]],
    [[21, 22], [23, 24], [25, 26], [27, 28], [29, 30], [31, 32], [33, 34], [35, 36], [37, 38], [39, 40]],
    [[41, 42], [43, 44], [45, 46], [47, 48], [49, 50], [51, 52], [53, 54], [55, 56], [57, 58], [59, 60]]
])

# 使用shrink函数处理数据，块大小为2
result = list(shrink(data, 2))
print(result)


[array([[ 2.,  3.],
       [22., 23.],
       [42., 43.]]), array([[ 6.,  7.],
       [26., 27.],
       [46., 47.]]), array([[10., 11.],
       [30., 31.],
       [50., 51.]]), array([[14., 15.],
       [34., 35.],
       [54., 55.]])]


In [40]:
result

[array([[ 2.,  3.],
        [22., 23.],
        [42., 43.]]),
 array([[ 6.,  7.],
        [26., 27.],
        [46., 47.]]),
 array([[10., 11.],
        [30., 31.],
        [50., 51.]]),
 array([[14., 15.],
        [34., 35.],
        [54., 55.]])]

In [41]:
class Record:
    
    def __init__(self):
        self.data = []
        self.mean = []
        self.count = 0
        self.reset()

    def reset(self):
        self.data = []
        self.mean = []
        self.count = 0

    def update(self, val, n=1):
        self.data.append(val)
        if type(val) == list:
            self.mean.append(np.mean(val))
        else:
            self.mean.append(val)
        self.count += n
    
    def get_latest(self, mean=True):
        if mean:
            return self.mean[-1]
        else:
            return self.data[-1]

    def delta(self):
        return abs(self.mean[-1] - self.mean[-2])

    def bigger(self):
        return self.mean[-1] - self.mean[-2] > 0

    def check(self, val):
        return (self.mean[-1] - np.mean(val)) < 0


In [42]:
import numpy as np

# 创建一个Record实例
record = Record()

# 更新实例，添加一个数值
record.update(10)

# 更新实例，添加一个列表
record.update([20, 30, 40])

# 获取最新的平均值
latest_mean = record.get_latest(mean=True)
print(f"最新的平均值: {latest_mean}")

# 获取最新的数据
latest_data = record.get_latest(mean=False)
print(f"最新的数据: {latest_data}")

# 计算最新的平均值与前一个平均值之间的差值
delta = record.delta()
print(f"最新的平均值与前一个平均值之间的差值: {delta}")

# 检查最新的平均值是否大于前一个平均值
is_bigger = record.bigger()
print(f"最新的平均值是否大于前一个平均值: {is_bigger}")

# 检查传入的值与最新的平均值之间的关系
check_result = record.check([50, 60])
print(f"传入的值与最新的平均值之间的关系: {check_result}")

最新的平均值: 30.0
最新的数据: [20, 30, 40]
最新的平均值与前一个平均值之间的差值: 20.0
最新的平均值是否大于前一个平均值: True
传入的值与最新的平均值之间的关系: True


In [45]:

class Params():
    """Class that loads hyperparameters from a json file.

    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5  # change the value of learning_rate in params
    ```
    """

    def __init__(self, json_path):
        with open(json_path) as f:
            params = json.load(f)
            self.__dict__.update(params)

    def save(self, json_path):
        with open(json_path, 'w') as f:
            json.dump(self.__dict__, f, indent=4)
            
    def update(self, json_path):
        """Loads parameters from json file"""
        with open(json_path) as f:
            params = json.load(f)
            self.__dict__.update(params)

    @property
    def dict(self):
        """Gives dict-like access to Params instance by `params.dict['learning_rate']"""
        return self.__dict__

In [46]:
import json

# 加载超参数
params = Params('hyperparameters.json')
print("Initial learning rate:", params.learning_rate)

# 修改超参数
params.learning_rate = 0.02
params.batch_size = 64
params.epochs = 200

# 保存修改后的超参数
params.save('updated_hyperparameters.json')

# 加载更新后的超参数
updated_params = Params('updated_hyperparameters.json')
print("Updated learning rate:", updated_params.learning_rate)
print("Updated batch size:", updated_params.batch_size)
print("Updated epochs:", updated_params.epochs)

Initial learning rate: 0.01
Updated learning rate: 0.02
Updated batch size: 64
Updated epochs: 200


In [47]:

class RunningAverage():
    """A simple class that maintains the running average of a quantity
    
    Example:
    ```
    loss_avg = RunningAverage()
    loss_avg.update(2)
    loss_avg.update(4)
    loss_avg() = 3
    ```
    """
    def __init__(self):
        self.steps = 0
        self.total = 0
    
    def update(self, val):
        self.total += val
        self.steps += 1
    
    def __call__(self):
        return self.total/float(self.steps)
        

In [48]:
loss_avg = RunningAverage()
loss_avg.update(2)
loss_avg.update(4)
print(loss_avg())

3.0


In [49]:
   
def set_logger(log_path):
    """Set the logger to log info in terminal and file `log_path`.

    In general, it is useful to have a logger so that every output to the terminal is saved
    in a permanent file. Here we save it to `model_dir/train.log`.

    Example:
    ```
    logging.info("Starting training...")
    ```

    Args:
        log_path: (string) where to log
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    if not logger.handlers:
        # Logging to a file
        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
        logger.addHandler(file_handler)

        # Logging to console
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(stream_handler)


In [51]:
import logging
import os
import tempfile

# 创建一个临时文件路径
temp_log_file = tempfile.NamedTemporaryFile(delete=False).name

# 调用 set_logger 函数
set_logger(temp_log_file)

# 记录一条日志消息
logging.info("This is a test log message.")

# 检查日志文件
with open(temp_log_file, 'r') as file:
    log_content = file.read()
    print(f"Log content: {log_content}")

# 清理临时文件
os.remove(temp_log_file)

INFO:root:This is a test log message.


Log content: 


In [52]:

def save_dict_to_json(d, json_path):
    """Saves dict of floats in json file

    Args:
        d: (dict) of float-castable values (np.float, int, float, etc.)
        json_path: (string) path to json file
    """
    with open(json_path, 'w') as f:
        # We need to convert the values to float for json (it doesn't accept np.array, np.float, )
        d = {k: float(v) for k, v in d.items()}
        json.dump(d, f, indent=4)


In [53]:
# 测试用例
def test_save_dict_to_json():
    # 创建一个字典
    test_dict = {
        "key1": 1,
        "key2": 2.5,
        "key3": 3.0,
        "key4": 4.0
    }
    
    # 指定保存的 JSON 文件路径
    json_path = "test_output.json"
    
    # 调用函数保存字典到 JSON 文件
    save_dict_to_json(test_dict, json_path)
    
    # 检查文件是否存在
    assert os.path.exists(json_path), "JSON file not created"
    
    # 读取并检查 JSON 文件内容
    with open(json_path, 'r') as f:
        loaded_dict = json.load(f)
        assert loaded_dict == test_dict, "Loaded data does not match original data"
    
    # 删除测试文件
    os.remove(json_path)

# 运行测试用例
test_save_dict_to_json()

In [54]:

def save_checkpoint(state, is_best, checkpoint):
    """Saves model and training parameters at checkpoint + 'last.pth.tar'. If is_best==True, also saves
    checkpoint + 'best.pth.tar'

    Args:
        state: (dict) contains model's state_dict, may contain other keys such as epoch, optimizer state_dict
        is_best: (bool) True if it is the best model seen till now
        checkpoint: (string) folder where parameters are to be saved
    """
    filepath = os.path.join(checkpoint, 'last.pth.tar')
    if not os.path.exists(checkpoint):
        print("Checkpoint Directory does not exist! Making directory {}".format(checkpoint))
        os.mkdir(checkpoint)
    else:
        print("Checkpoint Directory exists! ")
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'best.pth.tar'))


In [56]:
import os
import torch
import shutil

class TestSaveCheckpoint:
    def setUp(self):
        self.checkpoint_dir = "test_checkpoint"
        # 创建一个简单的线性层模型
        model = torch.nn.Linear(10, 1)
        self.state = {
            "epoch": 10,
            "state_dict": model.state_dict(),
            "optimizer_state_dict": torch.optim.Adam(model.parameters()).state_dict()
        }
        self.is_best = True

    def tearDown(self):
        if os.path.exists(self.checkpoint_dir):
            shutil.rmtree(self.checkpoint_dir)

    def test_save_checkpoint(self):
        save_checkpoint(self.state, self.is_best, self.checkpoint_dir)
        assert os.path.exists(os.path.join(self.checkpoint_dir, 'last.pth.tar')), "last.pth.tar not found"
        assert os.path.exists(os.path.join(self.checkpoint_dir, 'best.pth.tar')), "best.pth.tar not found"

# 创建一个实例并运行测试
test = TestSaveCheckpoint()
test.setUp()
test.test_save_checkpoint()
test.tearDown()

Checkpoint Directory does not exist! Making directory test_checkpoint


In [57]:

def load_checkpoint(checkpoint, model, optimizer=None):
    """Loads model parameters (state_dict) from file_path. If optimizer is provided, loads state_dict of
    optimizer assuming it is present in checkpoint.

    Args:
        checkpoint: (string) filename which needs to be loaded
        model: (torch.nn.Module) model for which the parameters are loaded
        optimizer: (torch.optim) optional: resume optimizer from checkpoint
    """
    if not os.path.exists(checkpoint):
        raise("File doesn't exist {}".format(checkpoint))
    checkpoint = torch.load(checkpoint)
    model.load_state_dict(checkpoint['state_dict'])

    if optimizer:
        optimizer.load_state_dict(checkpoint['optim_dict'])

    return checkpoint

In [None]:
import os
import torch
import shutil

class TestLoadCheckpoint:
    def setUp(self):
        self.checkpoint_dir = "test_checkpoint"
        self.model = torch.nn.Linear(10, 1)
        self.optimizer = torch.optim.Adam(self.model.parameters())
        self.state = {
            "epoch": 10,
            "state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict()
        }
        self.is_best = True
        save_checkpoint(self.state, self.is_best, self.checkpoint_dir)

    def tearDown(self):
        if os.path.exists(self.checkpoint_dir):
            shutil.rmtree(self.checkpoint_dir)

    def test_load_checkpoint(self):
        # 加载检查点
        loaded_state = load_checkpoint(os.path.join(self.checkpoint_dir, 'last.pth.tar'), self.model, self.optimizer)
        # 检查模型和优化器的状态是否正确加载
        assert self.state['epoch'] == loaded_state['epoch'], "Epoch mismatch"
        assert self.state['state_dict'] == loaded_state['state_dict'], "Model state_dict mismatch"
        assert self.state['optimizer_state_dict'] == loaded_state['optimizer_state_dict'], "Optimizer state_dict mismatch"

# 创建一个实例并运行测试
test = TestLoadCheckpoint()
test.setUp()
test.test_load_checkpoint()
test.tearDown()