In [1]:
import numpy as np
import matplotlib.pyplot as plt
from astropy.modeling import models, fitting
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm import tqdm
import copy

import sys
sys.path.append('../../wuchengzhou')
import sagan

wave_dict = sagan.utils.line_wave_dict
label_dict = sagan.utils.line_label_dict

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
uniform = np.random.uniform
normal = np.random.normal

def pnormal(mean, stddev):
    while True:
        value = normal(mean, stddev)
        if value >= 0:  # 确保值不为负
            return value

arg_dict_func = {
            'b_ha': {'amp_c':uniform, 'sigma_c':uniform, 'dv_c':normal, 'amp_w0':uniform, 'dv_w0':normal, 'sigma_w0':pnormal}, 
            'b_hb': {'amp_c':uniform, 'sigma_c':pnormal, 'dv_c':normal, 'amp_w0':uniform, 'dv_w0':normal, 'sigma_w0':pnormal}, 
            'b_hg': {'amp_c':uniform, 'sigma_c':pnormal, 'dv_c':normal}, 
            'n_ha':{'amp_c':pnormal}, 
            'n_hb':{'amp_c':pnormal}, 
            'n_hc':{'amp_c':pnormal}, 
            'line_o3': {'amp_c0':pnormal, 'sigma_c':pnormal, 'dv_c':normal, 'amp_w0':uniform, 'dv_w0':normal, 'sigma_w0':pnormal}, 
            'b_HeI': {'amp_c':pnormal, 'sigma_c':uniform, 'dv_c':normal}
        }
        
arg_dict_range = {
    'b_ha': {'amp_c':(1.5, 2.5), 'sigma_c':(1200, 1600), 'dv_c':(0, 75), 'amp_w0':(0.05, 0.6), 'dv_w0':(0, 400), 'sigma_w0':(5000, 400)}, 
    'b_hb': {'amp_c':(0.7, 1.7), 'sigma_c':(1500, 200), 'dv_c':(0, 75), 'amp_w0':(0.05, 0.3), 'dv_w0':(0, 100), 'sigma_w0':(5000, 450)}, 
    'b_hg': {'amp_c':(0.4, 0.9), 'sigma_c':(1500, 200), 'dv_c':(0, 75)}, 
    'n_ha':{'amp_c':(0.1, 0.05)}, 
    'n_hb':{'amp_c':(0.1, 0.05)}, 
    'n_hc':{'amp_c':(0.1, 0.05)}, 
    'line_o3': {'amp_c0':(1, 0.5), 'sigma_c':(500, 200), 'dv_c':(0, 75), 'amp_w0':(0.1, 0.5), 'dv_w0':(-100, 100), 'sigma_w0':(1700, 400)}, 
    'b_HeI': {'amp_c':(0.1, 0.08), 'sigma_c':(1400, 1800), 'dv_c':(0, 75)}
}

In [3]:
def generate_continuum(wave):
    # Generate random parameters for the power law
    amp1 = 10 * np.random.rand()
    amp2 = np.random.rand()
    alpha = uniform(0, 2)
    stddev = uniform(500, 2500)
    z = uniform(0, 0.01)
    
    # Create the model
    pl_amps = models.PowerLaw1D(amplitude=amp1, x_0=5500, alpha=alpha, fixed={'x_0': True})
    iron = sagan.IronTemplate(amplitude=amp2, stddev=stddev, z=z, name='Fe II')
    model = pl_amps + iron
    flux = model(wave)
    
    # Add noise
    noise = np.random.normal(0, 0.1, wave.size)
    flux += noise
    
    return flux

# narrow Line with 2 components
# Hb:2, oIII:2, narrow: 1, Ha: 2
def generate_spec(wave, arg_dict):

    amp_c0 = arg_dict['line_o3']['amp_c0']
    dv_c = arg_dict['line_o3']['dv_c']
    sigma_c = arg_dict['line_o3']['sigma_c']
    amp_w0 = arg_dict['line_o3']['amp_w0']
    dv_w0 = arg_dict['line_o3']['dv_w0']
    sigma_w0 = arg_dict['line_o3']['sigma_w0']

    line_o3 = sagan.Line_MultiGauss_doublet(n_components=2, amp_c0=amp_c0, amp_c1=0.2, dv_c=dv_c, sigma_c=sigma_c, wavec0=wave_dict['OIII_5007'], wavec1=wave_dict['OIII_4959'], name='[O III]', amp_w0=amp_w0, dv_w0=dv_w0, sigma_w0=sigma_w0)
    
    def tie_o3(model):
        return model['[O III]'].amp_c0 / 2.98
    line_o3.amp_c1.tied = tie_o3
    
    n_ha = sagan.Line_MultiGauss(n_components=1, amp_c=arg_dict['n_ha']['amp_c'], wavec=wave_dict['Halpha'], name=f'narrow {label_dict["Halpha"]}')
    n_hb = sagan.Line_MultiGauss(n_components=1, amp_c=arg_dict['n_hb']['amp_c'], wavec=wave_dict['Hbeta'], name=f'narrow {label_dict["Hbeta"]}')
    n_hg = sagan.Line_MultiGauss(n_components=1, amp_c=arg_dict['n_hc']['amp_c'], wavec=wave_dict['Hgamma'], name=f'narrow {label_dict["Hgamma"]}')

    
    b_HeI = sagan.Line_MultiGauss(n_components=1, amp_c=arg_dict['b_HeI']['amp_c'], dv_c=arg_dict['b_HeI']['dv_c'], sigma_c=arg_dict['b_HeI']['sigma_c'], wavec=5875.624, name=f'He I 5876')
    
    b_ha = sagan.Line_MultiGauss(n_components=2, amp_c=arg_dict['b_ha']['amp_c'], dv_c=arg_dict['b_ha']['dv_c'], sigma_c=arg_dict['b_ha']['sigma_c'], wavec=wave_dict['Halpha'], name=label_dict['Halpha'], amp_w0=arg_dict['b_ha']['amp_w0'], sigma_w0=arg_dict['b_ha']['sigma_w0'], dv_w0=arg_dict['b_ha']['dv_w0'])
    b_hb = sagan.Line_MultiGauss(n_components=2, amp_c=arg_dict['b_hb']['amp_c'], dv_c=arg_dict['b_hb']['dv_c'], sigma_c=arg_dict['b_hb']['sigma_c'], wavec=wave_dict['Hbeta'], name=label_dict['Hbeta'], amp_w0=arg_dict['b_hb']['amp_w0'], dv_w0=arg_dict['b_hb']['dv_w0'], sigma_w0=arg_dict['b_hb']['sigma_w0'])
    b_hg = sagan.Line_MultiGauss(n_components=1, amp_c=arg_dict['b_hg']['amp_c'], dv_c=arg_dict['b_hg']['dv_c'], sigma_c=arg_dict['b_hg']['sigma_c'], wavec=wave_dict['Hgamma'], name=label_dict['Hgamma'])
    
    def tie_narrow_sigma_c(model):
        return model['[O III]'].sigma_c

    def tie_narrow_dv_c(model):
        return model['[O III]'].dv_c

    for line in [n_ha, n_hb, n_hg]:
        line.sigma_c.tied = tie_narrow_sigma_c
        line.dv_c.tied = tie_narrow_dv_c
    
    line_ha = b_ha + n_ha
    line_hb = b_hb + n_hb
    line_hg = b_hg + n_hg

    # def model
    model = (line_ha + line_hb + line_hg + line_o3 + b_HeI)
    
    # Add Gaussian noise
    noise = np.random.normal(0, 0.015, wave.size)
    
    flux = model(wave) + noise
    
    return flux

In [4]:
def generate_data(arg_dict_func, arg_dict_range, num_samples=200, input_width=1000):
    X_list = []
    y_list = []
    
    for _ in range(num_samples):
        
        # arg_dict_func = {
        #     'b_ha': {'amp_c':uniform, 'sigma_c':uniform, 'dv_c':normal, 'amp_w0':uniform, 'dv_w0':normal, 'sigma_w0':pnormal}, 
        #     'b_hb': {'amp_c':uniform, 'sigma_c':pnormal, 'dv_c':normal, 'amp_w0':uniform, 'dv_w0':normal, 'sigma_w0':pnormal}, 
        #     'b_hg': {'amp_c':uniform, 'sigma_c':pnormal, 'dv_c':normal}, 
        #     'n_ha':{'amp_c':pnormal}, 
        #     'n_hb':{'amp_c':pnormal}, 
        #     'n_hc':{'amp_c':pnormal}, 
        #     'line_o3': {'amp_c0':pnormal, 'sigma_c':pnormal, 'dv_c':normal, 'amp_w0':uniform, 'dv_w0':normal, 'sigma_w0':pnormal}, 
        #     'b_HeI': {'amp_c':pnormal, 'sigma_c':uniform, 'dv_c':normal}
        # }
        
        # arg_dict_range = {
        #     'b_ha': {'amp_c':(1.5, 2.5), 'sigma_c':(1200, 1600), 'dv_c':(0, 75), 'amp_w0':(0.05, 0.6), 'dv_w0':(0, 400), 'sigma_w0':(5000, 400)}, 
        #     'b_hb': {'amp_c':(0.7, 1.7), 'sigma_c':(1500, 200), 'dv_c':(0, 75), 'amp_w0':(0.05, 0.3), 'dv_w0':(0, 100), 'sigma_w0':(5000, 450)}, 
        #     'b_hg': {'amp_c':(0.4, 0.9), 'sigma_c':(1500, 200), 'dv_c':(0, 75)}, 
        #     'n_ha':{'amp_c':(0.1, 0.05)}, 
        #     'n_hb':{'amp_c':(0.1, 0.05)}, 
        #     'n_hc':{'amp_c':(0.1, 0.05)}, 
        #     'line_o3': {'amp_c0':(1, 0.5), 'sigma_c':(500, 200), 'dv_c':(0, 75), 'amp_w0':(0.1, 0.5), 'dv_w0':(-100, 100), 'sigma_w0':(1700, 400)}, 
        #     'b_HeI': {'amp_c':(0.1, 0.08), 'sigma_c':(1400, 1800), 'dv_c':(0, 75)}
        # }
        
        arg_dict = {key: {param: arg_dict_func[key][param](*arg_dict_range[key][param]) for param in arg_dict_func[key]} for key in arg_dict_func}
        
        wave = np.linspace(4150, 7000, input_width)
        flux = generate_spec(wave, arg_dict=arg_dict)
        
        data = np.stack((wave, flux), axis=0)
        X_list.append(torch.tensor(data, dtype=torch.float32).view(1, 2, input_width))
        arg_list = [value for line in arg_dict.values() for value in line.values()]
        y_list.append(torch.tensor(arg_list, dtype=torch.float32))
    
    X = torch.cat(X_list, dim=0).reshape(num_samples, 1, 2, input_width)
    y = torch.stack(y_list)

    return X, y

In [5]:
class ImprovedCNN(nn.Module):
    def __init__(self, input_height=2, input_width=1000, output_dim=27):
        super(ImprovedCNN, self).__init__()
        # 初始标准化
        self.bn0 = nn.BatchNorm2d(1)
        
        # 第一个卷积块：添加残差连接
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(1, 21), padding=(0, 10))
        self.bn1 = nn.BatchNorm2d(32)
        self.conv1_res = nn.Conv2d(1, 32, kernel_size=(1, 1))  # 1x1卷积用于维度匹配
        
        # 第二个卷积块
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(1, 21), padding=(0, 10))
        self.bn2 = nn.BatchNorm2d(64)
        self.conv2_res = nn.Conv2d(32, 64, kernel_size=(1, 1))
        
        # 第三个卷积块：使用空洞卷积增加感受野
        self.conv3 = nn.Conv2d(64, 128, kernel_size=(1, 21), padding=(0, 20), dilation=(1, 2))
        self.bn3 = nn.BatchNorm2d(128)
        self.conv3_res = nn.Conv2d(64, 128, kernel_size=(1, 1))
        
        # 使用注意力机制
        self.se_block = SEBlock(128)
        
        # 池化层
        self.pool = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        
        # 全连接层
        self.fc_input_dim = self._calculate_fc_input_dim(input_height, input_width)
        self.dropout1 = nn.Dropout(0.3)
        self.fc1 = nn.Linear(self.fc_input_dim, 256)
        self.fc_bn1 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(256, 128)
        self.fc_bn2 = nn.BatchNorm1d(128)
        self.fc3 = nn.Linear(128, output_dim)
        
    def _calculate_fc_input_dim(self, input_height, input_width):
        with torch.no_grad():
            dummy_input = torch.zeros(1, 1, input_height, input_width)
            x = self.pool(self._conv_block3(self.pool(self._conv_block2(self.pool(self._conv_block1(self.bn0(dummy_input)))))))
            return x.numel()
    
    def _conv_block1(self, x):
        residual = self.conv1_res(x)
        x = F.leaky_relu(self.bn1(self.conv1(x)))
        return x + residual
    
    def _conv_block2(self, x):
        residual = self.conv2_res(x)
        x = F.leaky_relu(self.bn2(self.conv2(x)))
        return x + residual
    
    def _conv_block3(self, x):
        residual = self.conv3_res(x)
        x = F.leaky_relu(self.bn3(self.conv3(x)))
        x = self.se_block(x)
        return x + residual
    
    def forward(self, x):
        x = self.bn0(x)
        
        # 卷积模块
        x = self._conv_block1(x)
        x = self.pool(x)
        
        x = self._conv_block2(x)
        x = self.pool(x)
        
        x = self._conv_block3(x)
        x = self.pool(x)
        
        # 全连接模块
        x = x.view(x.size(0), -1)
        x = self.dropout1(x)
        x = F.leaky_relu(self.fc_bn1(self.fc1(x)))
        
        x = self.dropout2(x)
        x = F.leaky_relu(self.fc_bn2(self.fc2(x)))
        
        x = self.fc3(x)
        return x

# 添加注意力机制
class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

In [6]:
class ImprovedLoss(nn.Module):
    def __init__(self, arg_dict_func, arg_dict_range, alpha=0.1, beta=0.1):
        super(ImprovedLoss, self).__init__()
        self.arg_dict_func = arg_dict_func
        self.arg_dict_range = arg_dict_range
        
        # 根据参数范围设置权重
        self.w = []
        self.param_indices = {}  # 跟踪参数位置
        
        idx = 0
        for key1, line in arg_dict_func.items():
            self.param_indices[key1] = {}
            for key2, value in line.items():
                self.param_indices[key1][key2] = idx
                
                if value == uniform:
                    self.w.append(arg_dict_range[key1][key2][1] - arg_dict_range[key1][key2][0])
                elif value == pnormal or value == normal:
                    self.w.append(arg_dict_range[key1][key2][1])
                idx += 1
                
        self.w = torch.tensor(self.w, dtype=torch.float32).to(device)
        
        # 针对不同参数设置权重
        self.importance_weights = torch.ones_like(self.w)
        
        # 对重要参数增加权重（示例：增加Ha和Hb的权重）
        for key1 in ['b_ha', 'b_hb', 'line_o3']:
            for key2 in arg_dict_func[key1]:
                idx = self.param_indices[key1][key2]
                self.importance_weights[idx] = 2.0
        
        self.alpha = alpha  # L1正则化参数
        self.beta = beta    # 平滑项参数

    def normalize(self, x):
        return x / self.w

    def forward(self, outputs, targets):
        # 归一化目标值
        targets_norm = self.normalize(targets)
        
        # 使用Huber损失结合权重
        diff = outputs - targets_norm
        mse_loss = torch.pow(diff, 2)
        l1_loss = torch.abs(diff)
        
        # 使用平滑的L1损失
        smooth_l1 = torch.where(l1_loss < 1.0, 
                               0.5 * mse_loss,
                               l1_loss - 0.5)
        
        # 添加权重
        weighted_loss = smooth_l1 * self.importance_weights
        
        # 对不同样本添加不同权重
        # 对误差大的样本增加权重
        batch_weights = 1.0 + torch.mean(l1_loss, dim=1) * 0.5
        
        # 添加正则化项，鼓励预测值在合理范围内
        l1_reg = torch.mean(torch.abs(outputs))
        
        # 添加平滑项，惩罚输出的不连续性
        smoothness = torch.mean(torch.abs(outputs[:, 1:] - outputs[:, :-1]))
        
        # 总损失
        loss = (torch.mean(weighted_loss * batch_weights.unsqueeze(1)) + 
                self.alpha * l1_reg + 
                self.beta * smoothness)
        
        return loss

In [7]:
def calculate_accuracy(model, dataloader):
    loss = ImprovedLoss(arg_dict_func, arg_dict_range)
    model.eval()
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            outputs = model(inputs)
            # 计算相对误差
            relative_error = torch.abs(outputs * loss.w - targets) / torch.abs(targets)
            correct = torch.all(relative_error < 0.5, dim=1).sum().item()
            total_correct += correct
            total_samples += inputs.size(0)
    accuracy = total_correct / total_samples
    return accuracy

def plot_training_curve(train_losses, train_accuracies, test_accuracies):
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label="Training Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training Loss Curve")
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies, label="Train Accuracy")
    plt.plot(test_accuracies, label="Test Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title("Accuracy Curve")
    plt.legend()
    plt.show()

In [8]:
# 改进训练过程
def train_improved_model(model, train_loader, test_loader, criterion, num_epochs=400):
    # 使用AdamW优化器代替SGD
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
    
    # 学习率调度器：余弦退火
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
    
    # Warmup调度器
    warmup_epochs = 10
    warmup_scheduler = optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_epochs
    )
    
    train_losses = []
    train_accuracies = []
    test_accuracies = []
    
    # 梯度累积步数
    accumulation_steps = 4
    
    # EMA模型（指数移动平均）
    ema_model = copy.deepcopy(model)
    ema_decay = 0.999
    
    # 早停机制
    patience = 30
    best_loss = float('inf')
    counter = 0
    
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        optimizer.zero_grad()
        
        # 使用进度条
        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
            for i, (inputs, targets) in enumerate(pbar):
                inputs = inputs.to(device)
                targets = targets.to(device)
                
                # 前向传播
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                
                # 缩放损失以适应梯度累积
                loss = loss / accumulation_steps
                
                # 反向传播
                loss.backward()
                
                # 梯度累积
                if (i + 1) % accumulation_steps == 0 or (i + 1) == len(train_loader):
                    # 梯度裁剪
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    
                    # 更新参数
                    optimizer.step()
                    optimizer.zero_grad()
                    
                    # 更新EMA模型
                    with torch.no_grad():
                        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
                            ema_param.data = (ema_decay * ema_param.data + 
                                            (1 - ema_decay) * param.data)
                
                epoch_loss += loss.item() * accumulation_steps
                pbar.set_postfix({'loss': loss.item() * accumulation_steps})
            
        # 更新学习率
        if epoch < warmup_epochs:
            warmup_scheduler.step()
        else:
            scheduler.step()
        
        avg_loss = epoch_loss / len(train_loader)
        train_losses.append(avg_loss)
        
        # 计算训练集和测试集上的准确率
        model.eval()
        train_accuracy = calculate_accuracy(model, train_loader)
        
        # 使用EMA模型进行评估
        ema_model.eval()
        test_accuracy = calculate_accuracy(ema_model, test_loader)
        
        train_accuracies.append(train_accuracy)
        test_accuracies.append(test_accuracy)
        
        if (epoch + 1) % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, "
                 f"Train Acc: {train_accuracy:.4f}, Test Acc: {test_accuracy:.4f}, "
                 f"LR: {optimizer.param_groups[0]['lr']:.6f}")
        
        # 早停检查
        if avg_loss < best_loss:
            best_loss = avg_loss
            # 保存最佳模型
            torch.save(ema_model.state_dict(), "./model/best_model.pth")
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break
    
    return train_losses, train_accuracies, test_accuracies



In [9]:
# 使用改进后的方法训练模型
if __name__ == "__main__":
    
    # 超参数
    input_height = 2
    input_width = 1000
    output_dim = 27
    num_samples = 114514
    batch_size = 32
    num_epochs = 200
    learning_rate = 0.001

    try:
        X = torch.load("./data_generated/X.pt")
        y = torch.load("./data_generated/y.pt")
        print("数据已成功加载。")
    except FileNotFoundError:
        print("未找到数据，正在生成数据...")
        X, y = generate_data(arg_dict_func, arg_dict_range, num_samples, input_width)
        # 保存数据
        torch.save(X, "./data_generated/X.pt")
        torch.save(y, "./data_generated/y.pt")
        print("数据已成功生成并保存。")
    
    dataset = TensorDataset(X.to(device), y.to(device))
    train_size = int(0.8 * num_samples)
    test_size = num_samples - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    
    # 初始化改进的模型和损失函数
    model = ImprovedCNN(input_height, input_width, output_dim).to(device)
    criterion = ImprovedLoss(arg_dict_func, arg_dict_range)
    
    # 训练改进的模型
    train_losses, train_accuracies, test_accuracies = train_improved_model(
        model, train_loader, test_loader, criterion, num_epochs=num_epochs
    )
    
    # 绘制训练过程
    plot_training_curve(train_losses, train_accuracies, test_accuracies)
    
    # 保存模型
    model_name = 'cnn2'
    torch.save(model.state_dict(), f"./model/{model_name}.pth")
    print(f"Model saved to ./model/{model_name}.pth")
    
    # 加载模型并测试
    model.load_state_dict(torch.load(f"./model/{model_name}.pth"))
    test_accuracy = calculate_accuracy(model, test_loader)

数据已成功加载。


Epoch 1/200: 100%|██████████| 2863/2863 [00:38<00:00, 74.68it/s, loss=3.35]
Epoch 2/200: 100%|██████████| 2863/2863 [00:37<00:00, 75.77it/s, loss=1.34]
Epoch 3/200: 100%|██████████| 2863/2863 [00:39<00:00, 72.19it/s, loss=1.32]
Epoch 4/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.19it/s, loss=1.4] 
Epoch 5/200: 100%|██████████| 2863/2863 [00:47<00:00, 60.88it/s, loss=1.31]
Epoch 6/200: 100%|██████████| 2863/2863 [00:47<00:00, 60.19it/s, loss=1.31]
Epoch 7/200: 100%|██████████| 2863/2863 [00:47<00:00, 60.65it/s, loss=1.34]
Epoch 8/200: 100%|██████████| 2863/2863 [00:45<00:00, 62.64it/s, loss=1.25]
Epoch 9/200: 100%|██████████| 2863/2863 [00:47<00:00, 60.35it/s, loss=1.31]
Epoch 10/200: 100%|██████████| 2863/2863 [00:47<00:00, 60.87it/s, loss=1.3] 
Epoch 11/200:   0%|          | 7/2863 [00:00<00:42, 67.02it/s, loss=1.34]

Epoch [10/200], Loss: 1.3089, Train Acc: 0.0000, Test Acc: 0.0000, LR: 0.001000


Epoch 11/200: 100%|██████████| 2863/2863 [00:45<00:00, 62.47it/s, loss=1.29]
Epoch 12/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.26it/s, loss=1.31]
Epoch 13/200: 100%|██████████| 2863/2863 [00:47<00:00, 60.37it/s, loss=1.31]
Epoch 14/200: 100%|██████████| 2863/2863 [00:47<00:00, 60.47it/s, loss=1.29]
Epoch 15/200: 100%|██████████| 2863/2863 [00:47<00:00, 60.00it/s, loss=1.31]
Epoch 16/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.15it/s, loss=1.26]
Epoch 17/200: 100%|██████████| 2863/2863 [00:47<00:00, 60.82it/s, loss=1.37]
Epoch 18/200: 100%|██████████| 2863/2863 [00:47<00:00, 60.42it/s, loss=1.31]
Epoch 19/200: 100%|██████████| 2863/2863 [00:47<00:00, 60.42it/s, loss=1.24]
Epoch 20/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.35it/s, loss=1.3] 
Epoch 21/200:   0%|          | 0/2863 [00:00<?, ?it/s, loss=1.31]

Epoch [20/200], Loss: 1.3020, Train Acc: 0.0000, Test Acc: 0.0000, LR: 0.000994


Epoch 21/200: 100%|██████████| 2863/2863 [00:45<00:00, 62.61it/s, loss=1.24]
Epoch 22/200: 100%|██████████| 2863/2863 [00:45<00:00, 62.40it/s, loss=1.36]
Epoch 23/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.21it/s, loss=1.27]
Epoch 24/200: 100%|██████████| 2863/2863 [00:47<00:00, 60.12it/s, loss=1.3] 
Epoch 25/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.70it/s, loss=1.35]
Epoch 26/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.85it/s, loss=1.28]
Epoch 27/200: 100%|██████████| 2863/2863 [00:47<00:00, 60.69it/s, loss=1.31]
Epoch 28/200: 100%|██████████| 2863/2863 [00:45<00:00, 62.61it/s, loss=1.33]
Epoch 29/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.06it/s, loss=1.27]
Epoch 30/200: 100%|██████████| 2863/2863 [00:45<00:00, 62.53it/s, loss=1.28]
Epoch 31/200:   0%|          | 8/2863 [00:00<00:39, 71.96it/s, loss=1.35]

Epoch [30/200], Loss: 1.3014, Train Acc: 0.0000, Test Acc: 0.0000, LR: 0.000976


Epoch 31/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.20it/s, loss=1.31]
Epoch 32/200: 100%|██████████| 2863/2863 [00:47<00:00, 60.75it/s, loss=1.28]
Epoch 33/200: 100%|██████████| 2863/2863 [00:47<00:00, 60.86it/s, loss=1.26]
Epoch 34/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.67it/s, loss=1.29]
Epoch 35/200: 100%|██████████| 2863/2863 [00:47<00:00, 60.87it/s, loss=1.34]
Epoch 36/200: 100%|██████████| 2863/2863 [00:47<00:00, 60.88it/s, loss=1.34]
Epoch 37/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.85it/s, loss=1.34]
Epoch 38/200: 100%|██████████| 2863/2863 [00:45<00:00, 62.39it/s, loss=1.33]
Epoch 39/200: 100%|██████████| 2863/2863 [00:45<00:00, 62.27it/s, loss=1.27]
Epoch 40/200: 100%|██████████| 2863/2863 [00:45<00:00, 62.65it/s, loss=1.31]
Epoch 41/200:   0%|          | 7/2863 [00:00<00:40, 69.69it/s, loss=1.33]

Epoch [40/200], Loss: 1.3011, Train Acc: 0.0000, Test Acc: 0.0000, LR: 0.000946


Epoch 41/200: 100%|██████████| 2863/2863 [00:45<00:00, 63.14it/s, loss=1.28]
Epoch 42/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.28it/s, loss=1.28]
Epoch 43/200: 100%|██████████| 2863/2863 [00:46<00:00, 62.06it/s, loss=1.34]
Epoch 44/200: 100%|██████████| 2863/2863 [00:45<00:00, 62.39it/s, loss=1.3] 
Epoch 45/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.36it/s, loss=1.32]
Epoch 46/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.89it/s, loss=1.29]
Epoch 47/200: 100%|██████████| 2863/2863 [00:45<00:00, 62.36it/s, loss=1.34]
Epoch 48/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.66it/s, loss=1.33]
Epoch 49/200: 100%|██████████| 2863/2863 [00:45<00:00, 62.26it/s, loss=1.27]
Epoch 50/200: 100%|██████████| 2863/2863 [00:45<00:00, 62.45it/s, loss=1.26]
Epoch 51/200:   0%|          | 7/2863 [00:00<00:42, 66.49it/s, loss=1.32]

Epoch [50/200], Loss: 1.3008, Train Acc: 0.0000, Test Acc: 0.0000, LR: 0.000905


Epoch 51/200: 100%|██████████| 2863/2863 [00:45<00:00, 62.30it/s, loss=1.37]
Epoch 52/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.57it/s, loss=1.35]
Epoch 53/200: 100%|██████████| 2863/2863 [00:46<00:00, 62.03it/s, loss=1.28]
Epoch 54/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.54it/s, loss=1.31]
Epoch 55/200: 100%|██████████| 2863/2863 [00:48<00:00, 59.56it/s, loss=1.32]
Epoch 56/200: 100%|██████████| 2863/2863 [00:47<00:00, 60.51it/s, loss=1.32]
Epoch 57/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.64it/s, loss=1.32]
Epoch 58/200: 100%|██████████| 2863/2863 [00:45<00:00, 62.37it/s, loss=1.27]
Epoch 59/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.64it/s, loss=1.31]
Epoch 60/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.76it/s, loss=1.22]
Epoch 61/200:   0%|          | 8/2863 [00:00<00:41, 69.07it/s, loss=1.23]

Epoch [60/200], Loss: 1.3006, Train Acc: 0.0000, Test Acc: 0.0000, LR: 0.000854


Epoch 61/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.30it/s, loss=1.32]
Epoch 62/200: 100%|██████████| 2863/2863 [00:46<00:00, 61.56it/s, loss=1.26]
Epoch 63/200: 100%|██████████| 2863/2863 [00:37<00:00, 76.08it/s, loss=1.28]
Epoch 64/200: 100%|██████████| 2863/2863 [00:37<00:00, 75.67it/s, loss=1.33]
Epoch 65/200: 100%|██████████| 2863/2863 [00:38<00:00, 74.72it/s, loss=1.3] 
Epoch 66/200: 100%|██████████| 2863/2863 [01:09<00:00, 41.19it/s, loss=1.33]
Epoch 67/200: 100%|██████████| 2863/2863 [00:43<00:00, 66.45it/s, loss=1.28]
Epoch 68/200: 100%|██████████| 2863/2863 [01:01<00:00, 46.48it/s, loss=1.25]
Epoch 69/200: 100%|██████████| 2863/2863 [07:01<00:00,  6.79it/s, loss=1.28]  
Epoch 70/200: 100%|██████████| 2863/2863 [00:40<00:00, 70.77it/s, loss=1.33]
Epoch 71/200:   0%|          | 7/2863 [00:00<00:41, 68.87it/s, loss=1.26]

Epoch [70/200], Loss: 1.3005, Train Acc: 0.0000, Test Acc: 0.0000, LR: 0.000794


Epoch 71/200: 100%|██████████| 2863/2863 [00:40<00:00, 70.22it/s, loss=1.31]
Epoch 72/200: 100%|██████████| 2863/2863 [00:42<00:00, 67.23it/s, loss=1.28]
Epoch 73/200: 100%|██████████| 2863/2863 [00:41<00:00, 68.90it/s, loss=1.29]
Epoch 74/200: 100%|██████████| 2863/2863 [00:41<00:00, 68.36it/s, loss=1.3] 
Epoch 75/200: 100%|██████████| 2863/2863 [00:40<00:00, 71.15it/s, loss=1.29]
Epoch 76/200: 100%|██████████| 2863/2863 [03:02<00:00, 15.69it/s, loss=1.36]
Epoch 77/200: 100%|██████████| 2863/2863 [03:12<00:00, 14.85it/s, loss=1.26] 
Epoch 78/200: 100%|██████████| 2863/2863 [00:40<00:00, 70.04it/s, loss=1.33]
Epoch 79/200: 100%|██████████| 2863/2863 [00:41<00:00, 68.82it/s, loss=1.36]
Epoch 80/200: 100%|██████████| 2863/2863 [00:41<00:00, 68.84it/s, loss=1.27]
Epoch 81/200:   0%|          | 7/2863 [00:00<00:41, 68.44it/s, loss=1.35]

Epoch [80/200], Loss: 1.3004, Train Acc: 0.0000, Test Acc: 0.0000, LR: 0.000727


Epoch 81/200: 100%|██████████| 2863/2863 [00:42<00:00, 67.01it/s, loss=1.34]
Epoch 82/200: 100%|██████████| 2863/2863 [00:41<00:00, 69.19it/s, loss=1.36]
Epoch 83/200: 100%|██████████| 2863/2863 [00:41<00:00, 69.54it/s, loss=1.3] 
Epoch 84/200: 100%|██████████| 2863/2863 [00:42<00:00, 66.81it/s, loss=1.28]
Epoch 85/200: 100%|██████████| 2863/2863 [00:41<00:00, 69.26it/s, loss=1.31]
Epoch 86/200: 100%|██████████| 2863/2863 [00:43<00:00, 65.73it/s, loss=1.34]
Epoch 87/200: 100%|██████████| 2863/2863 [00:41<00:00, 69.54it/s, loss=1.35]
Epoch 88/200: 100%|██████████| 2863/2863 [00:43<00:00, 66.46it/s, loss=1.3] 
Epoch 89/200: 100%|██████████| 2863/2863 [00:41<00:00, 69.01it/s, loss=1.31]
Epoch 90/200: 100%|██████████| 2863/2863 [00:41<00:00, 68.67it/s, loss=1.3] 
Epoch 91/200:   0%|          | 8/2863 [00:00<00:45, 63.05it/s, loss=1.28]

Epoch [90/200], Loss: 1.3002, Train Acc: 0.0000, Test Acc: 0.0000, LR: 0.000655


Epoch 91/200: 100%|██████████| 2863/2863 [00:43<00:00, 65.67it/s, loss=1.25]
Epoch 92/200: 100%|██████████| 2863/2863 [00:42<00:00, 67.55it/s, loss=1.27]
Epoch 93/200:  19%|█▉        | 540/2863 [00:07<00:34, 67.58it/s, loss=1.3] 


KeyboardInterrupt: 