In [10]:
use_gpu = True
use_ramdon_split = False
use_dataparallel = True

In [11]:
# 详细GPU诊断
import torch
import os
import subprocess
import sys

# 添加路径以导入utils模块
sys.path.insert(0, '..')

print("=== 详细GPU诊断 ===")
print(f"PyTorch版本: {torch.__version__}")

# 安全地检查CUDA可用性
print("正在检查CUDA可用性...")
try:
    cuda_available = torch.cuda.is_available()
    print(f"CUDA可用: {cuda_available}")
    
    if cuda_available:
        try:
            device_count = torch.cuda.device_count()
            print(f"CUDA设备数量: {device_count}")
            
            if device_count > 0:
                print(f"当前CUDA设备: {torch.cuda.current_device()}")
                for i in range(device_count):
                    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
                    props = torch.cuda.get_device_properties(i)
                    print(f"  总内存: {props.total_memory / 1024**3:.1f} GB")
                    print(f"  计算能力: {props.major}.{props.minor}")
                    print(f"  多处理器数量: {props.multi_processor_count}")
            else:
                print("没有检测到CUDA设备")
        except Exception as e:
            print(f"获取CUDA设备信息时出错: {e}")
    else:
        print("CUDA不可用")
        
except Exception as e:
    print(f"检查CUDA时出错: {e}")
    print("可能的原因:")
    print("1. PyTorch安装有问题")
    print("2. CUDA驱动未安装")
    print("3. PyTorch版本与CUDA不兼容")

# 检查环境变量
print(f"\nCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', '未设置')}")

# 尝试使用nvidia-smi检查
print("\n=== 使用nvidia-smi检查系统GPU ===")
try:
    result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=10)
    if result.returncode == 0:
        print("nvidia-smi 输出:")
        print(result.stdout)
    else:
        print(f"nvidia-smi 错误: {result.stderr}")
except Exception as e:
    print(f"nvidia-smi 执行失败: {e}")
    print("可能的原因:")
    print("1. nvidia-smi 未安装")
    print("2. NVIDIA驱动未安装")
    print("3. 没有NVIDIA GPU")

print("==================")


=== 详细GPU诊断 ===
PyTorch版本: 2.7.1+cu118
正在检查CUDA可用性...
CUDA可用: True
CUDA设备数量: 1
当前CUDA设备: 0
GPU 0: NVIDIA GeForce RTX 2080
  总内存: 8.0 GB
  计算能力: 7.5
  多处理器数量: 46

CUDA_VISIBLE_DEVICES: 0

=== 使用nvidia-smi检查系统GPU ===
nvidia-smi 输出:
Wed Oct  8 17:19:53 2025       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 471.41       Driver Version: 471.41       CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:2D:00.0 Off |                  N/A |
| 40%   41C    P8     6W / 225W |   1454MiB /  8192MiB |     13%      Default |
|                               |                      |                  N/A |
+-----------------

In [12]:
# PyTorch安装检查
print("=== PyTorch安装检查 ===")
try:
    import torch
    print(f"✓ PyTorch已安装，版本: {torch.__version__}")
    
    # 检查是否有CUDA支持
    if torch.cuda.is_available():
        print("✓ PyTorch支持CUDA")
    else:
        print("✗ PyTorch不支持CUDA或CUDA不可用")
        
    # 检查torch.version是否存在（在PyTorch 2.7.1中可能不存在）
    if hasattr(torch, 'version'):
        print("✓ torch.version 模块存在")
        if hasattr(torch.version, 'cuda'):
            print(f"✓ CUDA版本: {torch.version.cuda}")
        else:
            print("✗ torch.version.cuda 不存在")
    else:
        print("ℹ torch.version 模块不存在（这在PyTorch 2.7.1中是正常的）")
        
except ImportError:
    print("✗ PyTorch未安装")
except Exception as e:
    print(f"✗ PyTorch检查失败: {e}")

print("==================")

# 测试GPU工具函数
print("=== 测试GPU工具函数 ===")
try:
    # 添加路径以导入utils模块
    import sys
    import os
    sys.path.insert(0, '..')
    
    from utils.gpu_tools import query_gpu, select_gpu
    
    # 查询GPU信息
    gpu_info = query_gpu()
    print(f"query_gpu() 返回结果:")
    for i, line in enumerate(gpu_info):
        print(f"  GPU {i}: {line.strip()}")
    
    # 选择GPU
    selected_gpus = select_gpu(gpu_info)
    print(f"select_gpu() 返回结果: {selected_gpus}")
    
    if selected_gpus:
        print(f"将使用GPU: {selected_gpus}")
    else:
        print("没有选择到可用的GPU")
        
except Exception as e:
    print(f"GPU工具函数测试失败: {e}")
    import traceback
    traceback.print_exc()

print("==================")


=== PyTorch安装检查 ===
✓ PyTorch已安装，版本: 2.7.1+cu118
✓ PyTorch支持CUDA
✓ torch.version 模块存在
✓ CUDA版本: 11.8
=== 测试GPU工具函数 ===
query_gpu() 返回结果:
  GPU 0: 0, NVIDIA GeForce RTX 2080, 6747 MiB
select_gpu() 返回结果: [0]
将使用GPU: [0]


In [13]:
# GPU状态检查
print("=== GPU状态检查 ===")
print(f"CUDA可用: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU数量: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"  内存: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.1f} GB")
else:
    print("没有可用的CUDA设备")
print("==================")


=== GPU状态检查 ===
CUDA可用: True
GPU数量: 1
GPU 0: NVIDIA GeForce RTX 2080
  内存: 8.0 GB


In [None]:
import os
import sys
sys.path.insert(0, '..')

if use_gpu:
    from utils.gpu_tools import *
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([ str(obj) for obj in select_gpu(query_gpu())])

import time
import datetime
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split



torch.manual_seed(42)

# 新的图像尺寸参数
# 78根K线 * 3像素宽 = 234，加6个空列 + padding4*2 = 248
IMAGE_WIDTH = 248
IMAGE_HEIGHT = 248
NUM_CHANNELS = 1  # 只使用R通道（K线数据）
NUM_KBARS = 78  # 昨日72根5分钟K线 + 今日开盘6根K线  

## load data

here we choose 1993-2001 data as our training(include validation) data, the remaining will be used in testing.

In [17]:
# 使用新的数据加载函数
# 数据格式：
# - 图像：(N, H, W, C) = (N, 248, 248, 3) - RGB格式
# - 标签：DataFrame with multiple label columns (10k_*, 5k_*)

# 设置数据目录
data_dir = "../notebooks/training_data"

# 直接实现数据加载函数
def load_training_data_simple(data_dir):
    """
    简单的训练数据加载函数
    """
    import glob
    
    print("=== 加载训练数据 ===")
    print(f"数据目录: {data_dir}")
    
    # 查找所有年份的数据文件
    image_files = glob.glob(os.path.join(data_dir, "year_*_images.dat"))
    label_files = glob.glob(os.path.join(data_dir, "year_*_labels.feather"))
    
    if not image_files or not label_files:
        print("错误: 没有找到数据文件！")
        return None, None
    
    # 提取年份
    years = []
    for file in image_files:
        year = int(os.path.basename(file).split('_')[1])
        years.append(year)
    years = sorted(years)
    
    print(f"找到年份: {years}")
    
    images_list = []
    labels_list = []
    
    for year in years:
        image_file = os.path.join(data_dir, f"year_{year}_images.dat")
        label_file = os.path.join(data_dir, f"year_{year}_labels.feather")
        
        if os.path.exists(image_file) and os.path.exists(label_file):
            print(f"加载 {year} 年数据...")
            
            # 加载图像数据
            img_data = np.memmap(image_file, dtype=np.uint8, mode='r')
            # 重塑为正确的形状 (N, 248, 248, 3)
            img_data = img_data.reshape((-1, 248, 248, 3))
            images_list.append(img_data)
            
            # 加载标签数据
            labels = pd.read_feather(label_file)
            labels_list.append(labels)
            
            print(f"  图像形状: {img_data.shape}")
            print(f"  标签形状: {labels.shape}")
        else:
            print(f"警告: {year} 年数据文件不完整，跳过")
    
    if images_list and labels_list:
        # 合并所有数据
        images = np.concatenate(images_list, axis=0)
        labels_df = pd.concat(labels_list, ignore_index=True)
        
        print(f"\n数据加载成功！")
        print(f"总图像数: {images.shape[0]}")
        print(f"总标签数: {len(labels_df)}")
        return images, labels_df
    else:
        print("错误: 没有成功加载任何数据！")
        return None, None

# 加载训练数据
images, label_df = load_training_data_simple(data_dir)

if images is not None and label_df is not None:
    print(f"\n数据加载成功！")
    print(f"图像形状: {images.shape}")
    print(f"标签形状: {label_df.shape}")
    print(f"标签列: {label_df.columns.tolist()}")
    
    # 显示年份分布
    if 'year' in label_df.columns:
        print(f"\n年份分布:")
        year_counts = label_df['year'].value_counts().sort_index()
        for year, count in year_counts.items():
            print(f"  {year}: {count} 个样本")
    
    # 显示标签统计
    print(f"\n标签统计:")
    for col in label_df.columns:
        if col.startswith(('10k_', '5k_')):
            positive_count = label_df[col].sum()
            total_count = len(label_df)
            percentage = (positive_count / total_count) * 100 if total_count > 0 else 0
            print(f"  {col}: {positive_count}/{total_count} ({percentage:.1f}%)")
    
    # 保存前两张图像为PNG文件用于检查
    import cv2
    import os
    
    # 创建输出目录
    debug_dir = "../pic/debug_images"
    os.makedirs(debug_dir, exist_ok=True)
    
    print(f"\n保存前两张图像用于检查...")
    for i in range(min(2, images.shape[0])):
        # 获取图像数据 (H, W, C)
        img = images[i]  # 形状: (248, 248, 3)
        
        # 保存RGB图像
        filename = f"sample_{i+1}_RGB.png"
        filepath = os.path.join(debug_dir, filename)
        cv2.imwrite(filepath, img)
        
        print(f"  保存: {filename}")
        print(f"    图像形状: {img.shape}")
        print(f"    数值范围: [{img.min()}, {img.max()}]")
        
        # 分别保存每个通道
        channel_names = ['R通道(K线)', 'G通道(EMA)', 'B通道(ZIGZAG)']
        for channel_idx in range(3):
            channel_img = img[:, :, channel_idx]  # 形状: (248, 248)
            
            filename = f"sample_{i+1}_{channel_names[channel_idx]}.png"
            filepath = os.path.join(debug_dir, filename)
            cv2.imwrite(filepath, channel_img)
            
            print(f"  保存: {filename}")
            print(f"    通道: {channel_names[channel_idx]}")
            print(f"    数值范围: [{channel_img.min()}, {channel_img.max()}]")
        
        # 打印标签信息
        if i < len(label_df):
            print(f"  标签信息: {label_df.iloc[i].to_dict()}")
        
        print()
    
    print(f"调试图像已保存到: {debug_dir}")
    
else:
    print("错误: 数据加载失败！")
    print(f"请检查数据目录: {data_dir}")
    print("确保存在以下格式的文件:")
    print("  - year_YYYY_images.dat")
    print("  - year_YYYY_labels.feather")

=== 加载训练数据 ===
数据目录: ../notebooks/training_data
找到年份: [2020]
加载 2020 年数据...
  图像形状: (15120, 248, 248, 3)
  标签形状: (15120, 18)

数据加载成功！
总图像数: 15120
总标签数: 15120

数据加载成功！
图像形状: (15120, 248, 248, 3)
标签形状: (15120, 18)
标签列: ['year', 'day_pair_idx', 'time_str', 'atr', '10k_4atr', '10k_2atr', '10k_1atr', '10k_0atr', '10k_-1atr', '10k_-2atr', '10k_-4atr', '5k_3atr', '5k_2atr', '5k_1atr', '5k_0atr', '5k_-1atr', '5k_-2atr', '5k_-3atr']

年份分布:
  2020: 15120 个样本

标签统计:
  10k_4atr: 1264.0/15120 (8.4%)
  10k_2atr: 2034.0/15120 (13.5%)
  10k_1atr: 3217.0/15120 (21.3%)
  10k_0atr: 3318.0/15120 (21.9%)
  10k_-1atr: 2381.0/15120 (15.7%)
  10k_-2atr: 2144.0/15120 (14.2%)
  10k_-4atr: 760.0/15120 (5.0%)
  5k_3atr: 453.0/15120 (3.0%)
  5k_2atr: 1474.0/15120 (9.7%)
  5k_1atr: 3779.0/15120 (25.0%)
  5k_0atr: 4535.0/15120 (30.0%)
  5k_-1atr: 2935.0/15120 (19.4%)
  5k_-2atr: 1211.0/15120 (8.0%)
  5k_-3atr: 733.0/15120 (4.8%)

保存前两张图像用于检查...
  保存: sample_1_RGB.png
    图像形状: (248, 248, 3)
    数值范围: [0, 255]
  保存: 

## build dataset

In [None]:
class MyDataset(Dataset):
    
    def __init__(self, img, label):
        self.img = torch.Tensor(img.copy())
        self.label = torch.Tensor(label)
        self.len = len(img)
  
    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        return self.img[idx], self.label[idx]

In [19]:
class MyDataset(Dataset):
    """
    股票图像数据集
    只使用R通道（K线数据）进行训练
    """
    
    def __init__(self, img, label):
        """
        参数:
        img: numpy数组，形状为 (N, H, W, C) - RGB格式
        label: numpy数组，形状为 (N,) 或 (N, num_classes)
        """
        # 提取R通道（K线数据）
        if len(img.shape) == 4 and img.shape[3] == 3:
            # RGB格式，提取R通道
            img_r = img[:, :, :, 0]  # 形状: (N, H, W)
        elif len(img.shape) == 3:
            # 已经是单通道
            img_r = img
        else:
            raise ValueError(f"不支持的图像格式: {img.shape}")
        
        # 添加通道维度 (N, H, W) -> (N, 1, H, W)
        img_r = np.expand_dims(img_r, axis=1)
        
        # 转换为float32并归一化到[0,1]
        self.img = torch.FloatTensor(img_r.astype(np.float32) / 255.0)
        self.label = torch.FloatTensor(label)
        self.len = len(img_r)
        
        print(f"数据集初始化:")
        print(f"  原始图像形状: {img.shape}")
        print(f"  R通道图像形状: {self.img.shape}")
        print(f"  标签形状: {self.label.shape}")
  
    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        return self.img[idx], self.label[idx]
# 创建二分类标签：10k_1atr及以上为涨（1），其他为跌（0）
print(f"\n=== 创建二分类标签 ===")

# 获取所有10k_标签列
label_columns = [col for col in label_df.columns if col.startswith('10k_')]
print(f"找到10k标签列: {label_columns}")

# 检查是否有10k_1atr及以上标签
target_columns = [col for col in label_columns if col in ['10k_1atr', '10k_2atr', '10k_4atr']]

if target_columns:
    print(f"使用标签列: {target_columns}")
    # 如果任何一个目标标签为1，则为涨（标签=1）
    labels = (label_df[target_columns].sum(axis=1) > 0).values.astype(int)
    print(f"二分类标签统计:")
    print(f"  涨（标签=1）: {labels.sum()} 个样本 ({labels.mean()*100:.1f}%)")
    print(f"  跌（标签=0）: {(1-labels).sum()} 个样本 ({(1-labels.mean())*100:.1f}%)")
else:
    print("错误: 没有找到10k_1atr及以上标签列！")
    labels = np.zeros(len(label_df), dtype=int)

# 数据集分割
if not use_ramdon_split:
    train_val_ratio = 0.7
    split_idx = int(images.shape[0] * 0.7)
    train_dataset = MyDataset(images[:split_idx], labels[:split_idx])
    val_dataset = MyDataset(images[split_idx:], labels[split_idx:])
    print(f"顺序分割: 训练集 {split_idx} 样本, 验证集 {len(images)-split_idx} 样本")
else:
    dataset = MyDataset(images, labels)
    train_val_ratio = 0.7
    train_dataset, val_dataset = random_split(dataset, \
        [int(dataset.len*train_val_ratio), dataset.len-int(dataset.len*train_val_ratio)], \
        generator=torch.Generator().manual_seed(42))
    del dataset
    print(f"随机分割: 训练集 {len(train_dataset)} 样本, 验证集 {len(val_dataset)} 样本")

train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size=256, shuffle=False, pin_memory=True)



=== 创建二分类标签 ===
找到10k标签列: ['10k_4atr', '10k_2atr', '10k_1atr', '10k_0atr', '10k_-1atr', '10k_-2atr', '10k_-4atr']
使用标签列: ['10k_4atr', '10k_2atr', '10k_1atr']
二分类标签统计:
  涨（标签=1）: 6515 个样本 (43.1%)
  跌（标签=0）: 8605 个样本 (56.9%)
数据集初始化:
  原始图像形状: (10584, 248, 248, 3)
  R通道图像形状: torch.Size([10584, 1, 248, 248])
  标签形状: torch.Size([10584])
数据集初始化:
  原始图像形状: (4536, 248, 248, 3)
  R通道图像形状: torch.Size([4536, 1, 248, 248])
  标签形状: torch.Size([4536])
顺序分割: 训练集 10584 样本, 验证集 4536 样本


Split method (not random split is recommended)

## models

In [26]:
# 直接在notebook中定义StockCNN模型类
import torch.nn as nn

def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.)
    elif isinstance(m, nn.Conv2d):
        torch.nn.init.xavier_uniform_(m.weight)

class StockCNN(nn.Module):
    """
    股票图像CNN模型
    专为248x248单通道图像设计
    """
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(5,3), stride=(3,1), dilation=(2,1), padding=(12,1)),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.MaxPool2d((2, 1), stride=(2, 1)),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(5,3), stride=(3,1), dilation=(2,1), padding=(12,1)),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.MaxPool2d((2, 1), stride=(2, 1)),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=(5,3), stride=(3,1), dilation=(2,1), padding=(12,1)),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.MaxPool2d((2, 1), stride=(2, 1)),
        )
        # 根据248x248输入计算的全连接层输入尺寸
        self.fc1 = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(253952, 2),  # 二分类输出
        )
        self.softmax = nn.Softmax(dim=1)
       
    def forward(self, x):
        # 输入格式: (batch_size, 1, 248, 248)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = x.reshape(x.size(0), -1)  # 动态计算flatten后的尺寸
        x = self.fc1(x)
        #x = self.softmax(x)
        return x

In [27]:
device = 'cuda' if use_gpu else 'cpu'
export_onnx = True

# 使用新的StockCNN模型，专为248x248单通道图像设计
print(f"使用StockCNN模型（专为248x248单通道图像设计）")

# 创建模型
net = StockCNN().to(device)
net.apply(init_weights)

if export_onnx:
    import torch.onnx
    # 更新输入尺寸为新的格式 (batch_size, channels, height, width)
    x = torch.randn([1, 1, IMAGE_HEIGHT, IMAGE_WIDTH]).to(device)
    torch.onnx.export(net,               # model being run
                      x,                         # model input (or a tuple for multiple inputs)
                      "../cnn_baseline.onnx",   # where to save the model (can be a file or file-like object)
                      export_params=False,        # store the trained parameter weights inside the model file
                      opset_version=10,          # the ONNX version to export the model to
                      do_constant_folding=False,  # whether to execute constant folding for optimization
                      input_names = ['input_images'],   # the model's input names
                      output_names = ['output_prob'], # the model's output names
                      dynamic_axes={'input_images' : {0 : 'batch_size'},    # variable length axes
                                     'output_prob' : {0 : 'batch_size'}})


使用StockCNN模型（专为248x248单通道图像设计）


### Profiling

## train

In [28]:
def train_loop(dataloader, net, loss_fn, optimizer):
    
    running_loss = 0.0
    current = 0
    net.train()
    
    with tqdm(dataloader) as t:
        for batch, (X, y) in enumerate(t):
            X = X.to(device)
            y = y.to(device)
            y_pred = net(X)
            loss = loss_fn(y_pred, y.long())
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss = (len(X) * loss.item() + running_loss * current) / (len(X) + current)
            current += len(X)
            t.set_postfix({'running_loss':running_loss})
    
    return running_loss

In [29]:
def val_loop(dataloader, net, loss_fn):

    running_loss = 0.0
    current = 0
    net.eval()
    
    with torch.no_grad():
        with tqdm(dataloader) as t:
            for batch, (X, y) in enumerate(t):
                X = X.to(device)
                y = y.to(device)
                y_pred = net(X)
                loss = loss_fn(y_pred, y.long())

                running_loss += loss.item()
                running_loss = (len(X) * running_loss + loss.item() * current) / (len(X) + current)
                current += len(X)
            
    return running_loss

In [None]:
# net = torch.load('/home/clidg/proj_2/pt/baseline_epoch_10_train_0.6865865240322523_eval_0.686580_.pt')

In [30]:
# 检查GPU可用性并设置DataParallel
use_gpu = True
if use_gpu:
    # 检查CUDA是否可用
    if not torch.cuda.is_available():
        print("警告: CUDA不可用，将使用CPU")
        use_gpu = False
        device = 'cpu'
        net = net.to(device)
    else:
        # 检查可用GPU数量
        available_gpus = torch.cuda.device_count()
        print(f"检测到 {available_gpus} 个GPU")
        
        if available_gpus == 0:
            print("警告: 没有可用的GPU，将使用CPU")
            use_gpu = False
            device = 'cpu'
            net = net.to(device)
        elif available_gpus == 1:
            print("只有一个GPU可用，不使用DataParallel")
            net = net.to(device)
        else:
            print(f"使用DataParallel，GPU数量: {available_gpus}")
            net = net.to(device)
            net = nn.DataParallel(net)
elif use_gpu:
    net = net.to(device)

检测到 1 个GPU
只有一个GPU可用，不使用DataParallel


In [31]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-5)

start_epoch = 0
min_val_loss = 1e9
last_min_ind = -1
early_stopping_epoch = 5

from torch.utils.tensorboard import SummaryWriter
tb = SummaryWriter()

In [32]:
start_time = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
os.mkdir('..\\pt'+os.sep+start_time)
epochs = 100
for t in range(start_epoch, epochs):
    print(f"Epoch {t}\n-------------------------------")
    time.sleep(0.2)
    train_loss = train_loop(train_dataloader, net, loss_fn, optimizer)
    val_loss = val_loop(val_dataloader, net, loss_fn)
    tb.add_histogram("train_loss", train_loss, t)
    torch.save(net, '../pt'+os.sep+start_time+os.sep+'baseline_epoch_{}_train_{:5f}_val_{:5f}.pt'.format(t, train_loss, val_loss)) 
    if val_loss < min_val_loss:
        last_min_ind = t
        min_val_loss = val_loss
    elif t - last_min_ind >= early_stopping_epoch:
        break

print('Done!')
print('Best epoch: {}, val_loss: {}'.format(last_min_ind, min_val_loss))

Epoch 0
-------------------------------


100%|██████████| 83/83 [00:26<00:00,  3.15it/s, running_loss=1.18]
100%|██████████| 18/18 [00:03<00:00,  4.80it/s]


Epoch 1
-------------------------------


100%|██████████| 83/83 [00:24<00:00,  3.40it/s, running_loss=1.09]
100%|██████████| 18/18 [00:03<00:00,  5.05it/s]


Epoch 2
-------------------------------


100%|██████████| 83/83 [00:24<00:00,  3.35it/s, running_loss=1.02]
100%|██████████| 18/18 [00:03<00:00,  4.99it/s]


Epoch 3
-------------------------------


100%|██████████| 83/83 [00:25<00:00,  3.24it/s, running_loss=0.989]
100%|██████████| 18/18 [00:03<00:00,  4.89it/s]


Epoch 4
-------------------------------


100%|██████████| 83/83 [00:25<00:00,  3.23it/s, running_loss=0.929]
100%|██████████| 18/18 [00:03<00:00,  4.93it/s]


Epoch 5
-------------------------------


100%|██████████| 83/83 [00:25<00:00,  3.22it/s, running_loss=0.906]
100%|██████████| 18/18 [00:03<00:00,  4.68it/s]


Epoch 6
-------------------------------


100%|██████████| 83/83 [00:25<00:00,  3.26it/s, running_loss=0.861]
100%|██████████| 18/18 [00:03<00:00,  4.94it/s]


Epoch 7
-------------------------------


100%|██████████| 83/83 [00:25<00:00,  3.25it/s, running_loss=0.836]
100%|██████████| 18/18 [00:03<00:00,  4.91it/s]


Epoch 8
-------------------------------


100%|██████████| 83/83 [00:25<00:00,  3.25it/s, running_loss=0.82] 
100%|██████████| 18/18 [00:03<00:00,  4.84it/s]


Epoch 9
-------------------------------


100%|██████████| 83/83 [00:25<00:00,  3.22it/s, running_loss=0.774]
100%|██████████| 18/18 [00:03<00:00,  4.91it/s]


Epoch 10
-------------------------------


100%|██████████| 83/83 [00:25<00:00,  3.25it/s, running_loss=0.752]
100%|██████████| 18/18 [00:03<00:00,  4.92it/s]


Epoch 11
-------------------------------


100%|██████████| 83/83 [00:25<00:00,  3.23it/s, running_loss=0.757]
100%|██████████| 18/18 [00:03<00:00,  4.85it/s]


Epoch 12
-------------------------------


100%|██████████| 83/83 [00:25<00:00,  3.25it/s, running_loss=0.717]
100%|██████████| 18/18 [00:03<00:00,  4.85it/s]


Epoch 13
-------------------------------


100%|██████████| 83/83 [00:25<00:00,  3.24it/s, running_loss=0.691]
100%|██████████| 18/18 [00:03<00:00,  4.87it/s]

Done!
Best epoch: 8, val_loss: 0.6811516505452196



