## 这个notebook意在和论文训练方法对齐

In [1]:
# Show the kernel's working directory; relative paths used later (e.g. "./cnn_model") resolve against it.
pwd

'/home/jwangiy/Reimage/my_reimagine/reimagine/notebooks'

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Widen pandas' display limits so wide/long frames are not truncated.
for _display_option in ("display.max_columns", "display.max_rows"):
    pd.set_option(_display_option, 200)

In [3]:
import os.path as op

# Chart-image geometry in pixels, keyed by look-back window in days.
# Width is 3 pixels per day; height is a fixed value per window.
IMAGE_WIDTH = {days: 3 * days for days in (5, 20, 60)}
IMAGE_HEIGHT = {5: 32, 20: 64, 60: 96}

#### Project structure
[Cookiecutter Data Science directory structure](https://drivendata.github.io/cookiecutter-data-science/#directory-structure)

### Build Models

In [4]:
# Label column used as the classification target.
target = "Retx_20d_label"

In [5]:
import pandas as pd
import numpy as np
import os
from matplotlib import pyplot as plt
from tqdm import tqdm
import pickle

from sklearn.model_selection import train_test_split
from torch.utils import data 
from torch.autograd import Variable 
import torch
import torch.nn as nn
from torch.nn import init

import torch
from torchvision import datasets,transforms
import torch.utils.data as Dataset

In [6]:
# Years used for training/validation and for out-of-sample testing.
train_val_years = list(range(1993, 1999 + 1))
test_years = list(range(1999 + 1, 2019 + 1))

# Each yearly .dat file is a flat uint8 buffer viewed as (n_images, height, width) of 20-day charts.
_IMG_SHAPE_20D = (-1, IMAGE_HEIGHT[20], IMAGE_WIDTH[20])


def _image_file(year):
    """Return the path of the 20-day image memmap for ``year``."""
    return op.join("/home/jwangiy/Reimage/img_data/monthly_20d",
                   f"20d_month_has_vb_[20]_ma_{year}_images.dat")


images_train = []
for year in train_val_years:
    # Read-only memmap view; np.concatenate below copies everything into one in-memory array.
    images_temp = np.memmap(_image_file(year), dtype=np.uint8, mode='r').reshape(_IMG_SHAPE_20D)
    images_train.append(images_temp)
images_train = np.concatenate(images_train)

images_test = []
for year in test_years:
    images_temp = np.memmap(_image_file(year), dtype=np.uint8, mode='r').reshape(_IMG_SHAPE_20D)
    images_test.append(images_temp)
images_test = np.concatenate(images_test)

In [7]:
# Drop the leftover per-year memmap name, then display both stacked shapes.
del images_temp
(images_train.shape, images_test.shape)

((694871, 64, 60), (1502123, 64, 60))

In [8]:
# Same target column as bound earlier (harmless re-binding before the label cells).
target = "Retx_20d_label"

In [9]:
def _label_file(year):
    """Return the path of the per-year label table matching the image memmap for ``year``."""
    return op.join("/home/jwangiy/Reimage/img_data/monthly_20d",
                   f"20d_month_has_vb_[20]_ma_{year}_labels_w_delay.feather")


label_train = []
for year in train_val_years:
    # Only the target column is used downstream, so read just that column.
    label_temp = pd.read_feather(_label_file(year), columns=[target])
    label_train.append(label_temp)
# ignore_index=True: the per-year frames presumably each carry their own 0..n-1 index,
# which a plain concat would turn into duplicate index labels.
label_train = pd.concat(label_train, ignore_index=True)[[target]]

label_test = []
for year in test_years:
    label_temp = pd.read_feather(_label_file(year), columns=[target])
    label_test.append(label_temp)
label_test = pd.concat(label_test, ignore_index=True)[[target]]

# Labels are paired with images purely by row position, so the counts must agree.
assert len(label_train) == len(images_train), "train labels and images are misaligned"
assert len(label_test) == len(images_test), "test labels and images are misaligned"

In [10]:
# Drop the leftover per-year frame name, then display both label-frame shapes.
del label_temp
(label_train.shape, label_test.shape)

((694871, 1), (1502123, 1))

In [11]:
# Binarise the target in place: 0 stays 0, every other value becomes 1 (int64).
# NOTE(review): "every other value" includes NaN or any sentinel code present in the raw
# labels, which would then be counted as class 1 — confirm the raw label encoding.
for _labels in (label_test, label_train):
    _labels[target] = _labels[target].ne(0).astype("int64")

In [12]:
from sklearn.model_selection import train_test_split

# Random 70/30 train/validation split of the in-sample images, with a fixed seed.
x_train, x_val, y_train, y_val = train_test_split(
    images_train,
    label_train,
    test_size=0.3,
    shuffle=True,
    random_state=0,
)

In [13]:
class CNNDataset(Dataset.Dataset):
    """Map-style dataset pairing images with labels by row index.

    Args:
        Data: indexable collection of images; ``Data[i]`` is returned as a float32 tensor.
        Label: indexable collection of labels aligned row-for-row with ``Data``;
            ``Label[i]`` is returned as a float32 tensor (shape (1,) when ``Label``
            is an (N, 1) array, a 0-d tensor when ``Label`` is 1-D).
    """

    def __init__(self, Data, Label):
        self.Data = Data
        self.Label = Label

    def __len__(self):
        return len(self.Data)

    def __getitem__(self, index):
        # torch.as_tensor rather than the legacy torch.Tensor(...) constructor: the legacy
        # constructor can treat a scalar argument as a *size* (returning uninitialised
        # memory), which would bite if Label were 1-D and Label[index] a scalar.
        image = torch.as_tensor(self.Data[index], dtype=torch.float32)
        label = torch.as_tensor(self.Label[index], dtype=torch.float32)
        return image, label

In [14]:
train_dataset = CNNDataset(x_train, y_train.values)
val_dataset = CNNDataset(x_val, y_val.values)
test_dataset = CNNDataset(images_test, label_test.values)

BATCH_SIZE = 128

# Only the training data needs shuffling. The validation set is purely evaluated, and a fixed
# order keeps the per-epoch validation loss reproducible: with shuffling, the short final batch
# changes composition every epoch, which perturbs a mean of per-batch losses used for early stopping.
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [15]:
# Report GPU availability. The device queries only run when CUDA is present, because
# torch.cuda.current_device() raises on CPU-only machines.
cuda_ok = torch.cuda.is_available()
print(cuda_ok)
if cuda_ok:
    print(torch.cuda.current_device())
    print([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])

True
0
['NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090', 'NVIDIA GeForce RTX 3090']


In [16]:
class CNN(nn.Module):
    """Three-stage CNN classifier for 64x60 chart images, producing 2 logits.

    Every stage is Conv2d(kernel 5x3, stride (3,1), dilation (2,1), padding (12,1))
    -> BatchNorm2d -> LeakyReLU(0.01) -> MaxPool2d(2x1), with channels 1 -> 64 -> 128 -> 256.
    For a 64x60 input the heights go 64 -> 27 -> 13 -> 10 -> 5 -> 7 -> 3, leaving
    256 x 3 x 60 = 46080 features for the dropout + linear head.
    """

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Build one conv -> BN -> LeakyReLU -> max-pool stage (submodule order fixes state_dict keys)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(5, 3), stride=(3, 1),
                      dilation=(2, 1), padding=(12, 1)),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.MaxPool2d((2, 1), stride=(2, 1)),
        )

    def __init__(self):
        super().__init__()
        # Attribute names and construction order are preserved so checkpoint keys and
        # seeded initialisation are unchanged.
        self.layer1 = self._conv_block(1, 64)
        self.layer2 = self._conv_block(64, 128)
        self.layer3 = self._conv_block(128, 256)
        self.fc1 = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(46080, 2),
        )

        # Xavier-uniform weights and zero biases for every conv / linear layer.
        for module in self.modules():
            if not isinstance(module, (nn.Conv2d, nn.Linear)):
                continue
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        # Add the single channel axis: (N, 64, 60) -> (N, 1, 64, 60).
        x = x.reshape(-1, 1, 64, 60)
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        return self.fc1(x.reshape(-1, 46080))

In [17]:
GPU_INDEX = 3  # which visible GPU to train on

model = CNN()
# Use GPU `GPU_INDEX` only if it exists; otherwise fall back to the CPU instead of crashing.
device = torch.device(f"cuda:{GPU_INDEX}" if torch.cuda.device_count() > GPU_INDEX else "cpu")
device_ids = range(torch.cuda.device_count())  # indices of all visible GPUs
model.to(device)

CNN(
  (layer1): Sequential(
    (0): Conv2d(1, 64, kernel_size=(5, 3), stride=(3, 1), padding=(12, 1), dilation=(2, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01, inplace=True)
    (3): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(5, 3), stride=(3, 1), padding=(12, 1), dilation=(2, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01, inplace=True)
    (3): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)
  )
  (layer3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(5, 3), stride=(3, 1), padding=(12, 1), dilation=(2, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01, inplace=True)
    (3): MaxPoo

In [18]:
# Cross-entropy over the 2 output logits; Adam with a fixed base learning rate of 1e-5.
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# Halve the learning rate (lr * 0.5) once the monitored loss has not improved for 5 consecutive
# scheduler steps. `verbose` is omitted: it is deprecated in recent PyTorch and defaulted to off.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5)

In [19]:
# Checkpoint directory. exist_ok keeps the cell re-runnable while still surfacing real
# failures (e.g. permission errors) that a blanket `except Exception: pass` would hide.
os.makedirs("./cnn_model", exist_ok=True)

In [20]:
class EarlyStopping:
    """Early-stop training when the validation loss stops improving (adapted from pytorchtools).

    Each call compares ``-val_loss`` with the best score seen so far. A loss that is at least
    ``delta`` below the best loss counts as an improvement: the model's ``state_dict`` is saved
    to ``path`` and the counter resets. Otherwise the counter increments, and ``early_stop`` is
    set once it reaches ``patience``.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How many calls without improvement to tolerate.
                            Default: 7
            verbose (bool): If True, reports each improvement before saving.
                            Default: False
            delta (float): Minimum decrease in loss that qualifies as an improvement.
                            Default: 0
            path (str): Path the best model's state_dict is saved to.
                            Default: 'checkpoint.pt'
            trace_func (callable): Function used for every message.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf (lower-case): the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
        self.trace_func(self.path)

    def __call__(self, val_loss, model):
        """Record one validation loss; save ``model`` on improvement, else advance the counter."""
        score = -val_loss
        if self.best_score is None:
            # First call always counts as the best so far.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Save the model's state_dict to ``self.path`` and remember ``val_loss`` as the best loss.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [21]:
# Stop after 2 epochs without a validation-loss drop of at least 1e-4; best weights go to `path`.
early_stopping = EarlyStopping(
    patience=2,
    verbose=True,
    delta=0.0001,
    path='./cnn_model/ret_20_classification_model_stat_checkpoint.pt',
)

./cnn_model/ret_20_classification_model_stat_checkpoint.pt


In [22]:
%%time

loss_count = []
epochs = 100
global_loss_train = []
global_loss_test = []
global_loss_val = []

lr = []

for epoch in range(epochs):
    
    lr.append(optimizer.param_groups[0]["lr"])
    
    for i,(x,y) in enumerate(train_loader):
        batch_x = Variable(x.cuda(device=device_ids[3]))
        batch_y = Variable(y.cuda(device=device_ids[3]))
        out = model(batch_x.float())
        loss = loss_func(out,batch_y.squeeze().long())
        optimizer.zero_grad()
        loss.backward() 
        optimizer.step()
        if i%20 == 0:
            temp = loss.cpu()
            loss_count.append(temp.detach().numpy())
            print('epoch:', format(epoch+1),f'iteration: {i+1}:\t','loss:', loss.item())
            torch.save(model,r'./cnn_model/ret_20_classification_model_checkpoint.pt')
            
    
    
    
    loss_val_epoch = [] 
    for x,y in val_loader:
        batch_x = Variable(x.cuda(device=device_ids[3]))
        batch_y = Variable(y.cuda(device=device_ids[3]))
        prediction = model(batch_x)
        loss = loss_func(prediction, batch_y.squeeze().long())
        loss_val_epoch.append(loss.cpu().detach().numpy())
        
    loss_val = np.mean(loss_val_epoch)
    global_loss_val.append(loss_val)

    loss_train = loss_count[-1]
    global_loss_train.append(loss_train)
    
    scheduler.step(loss_train) # 每次epoch后更新learning rate
    
    early_stopping(loss_val, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break
    print('----------------epoch '+str(epoch+1)+' end---------------------')

epoch: 1 iteration: 1:	 loss: 1.0317931175231934
epoch: 1 iteration: 21:	 loss: 1.0840809345245361
epoch: 1 iteration: 41:	 loss: 1.2600007057189941
epoch: 1 iteration: 61:	 loss: 0.9934634566307068
epoch: 1 iteration: 81:	 loss: 1.2818241119384766
epoch: 1 iteration: 101:	 loss: 1.153946042060852
epoch: 1 iteration: 121:	 loss: 1.1520583629608154
epoch: 1 iteration: 141:	 loss: 1.0106247663497925
epoch: 1 iteration: 161:	 loss: 1.0944043397903442
epoch: 1 iteration: 181:	 loss: 1.1562525033950806
epoch: 1 iteration: 201:	 loss: 1.1064752340316772
epoch: 1 iteration: 221:	 loss: 1.0931023359298706
epoch: 1 iteration: 241:	 loss: 1.0343654155731201
epoch: 1 iteration: 261:	 loss: 1.021734356880188
epoch: 1 iteration: 281:	 loss: 1.071834683418274
epoch: 1 iteration: 301:	 loss: 1.0874319076538086
epoch: 1 iteration: 321:	 loss: 1.0250993967056274
epoch: 1 iteration: 341:	 loss: 1.142384648323059
epoch: 1 iteration: 361:	 loss: 1.102548599243164
epoch: 1 iteration: 381:	 loss: 1.02042317

epoch: 1 iteration: 3201:	 loss: 0.799727201461792
epoch: 1 iteration: 3221:	 loss: 0.8961281180381775
epoch: 1 iteration: 3241:	 loss: 0.8864737153053284
epoch: 1 iteration: 3261:	 loss: 0.8367908000946045
epoch: 1 iteration: 3281:	 loss: 0.8090585470199585
epoch: 1 iteration: 3301:	 loss: 0.8210048675537109
epoch: 1 iteration: 3321:	 loss: 0.8596818447113037
epoch: 1 iteration: 3341:	 loss: 0.8321543335914612
epoch: 1 iteration: 3361:	 loss: 0.9096159338951111
epoch: 1 iteration: 3381:	 loss: 0.7867128849029541
epoch: 1 iteration: 3401:	 loss: 0.8222695589065552
epoch: 1 iteration: 3421:	 loss: 0.6658604145050049
epoch: 1 iteration: 3441:	 loss: 0.8490144610404968
epoch: 1 iteration: 3461:	 loss: 0.8971728086471558
epoch: 1 iteration: 3481:	 loss: 0.8535088300704956
epoch: 1 iteration: 3501:	 loss: 0.8180292844772339
epoch: 1 iteration: 3521:	 loss: 0.7195167541503906
epoch: 1 iteration: 3541:	 loss: 0.8110882639884949
epoch: 1 iteration: 3561:	 loss: 0.802613377571106
epoch: 1 itera

epoch: 2 iteration: 2521:	 loss: 0.7513440251350403
epoch: 2 iteration: 2541:	 loss: 0.7582570910453796
epoch: 2 iteration: 2561:	 loss: 0.835556149482727
epoch: 2 iteration: 2581:	 loss: 0.6985078454017639
epoch: 2 iteration: 2601:	 loss: 0.8165929913520813
epoch: 2 iteration: 2621:	 loss: 0.7433362603187561
epoch: 2 iteration: 2641:	 loss: 0.805871307849884
epoch: 2 iteration: 2661:	 loss: 0.8092523813247681
epoch: 2 iteration: 2681:	 loss: 0.7356683015823364
epoch: 2 iteration: 2701:	 loss: 0.8442575931549072
epoch: 2 iteration: 2721:	 loss: 0.7216255068778992
epoch: 2 iteration: 2741:	 loss: 0.7316858768463135
epoch: 2 iteration: 2761:	 loss: 0.7876774668693542
epoch: 2 iteration: 2781:	 loss: 0.7722417712211609
epoch: 2 iteration: 2801:	 loss: 0.7422717213630676
epoch: 2 iteration: 2821:	 loss: 0.7540525197982788
epoch: 2 iteration: 2841:	 loss: 0.7290620803833008
epoch: 2 iteration: 2861:	 loss: 0.7340658903121948
epoch: 2 iteration: 2881:	 loss: 0.7421233057975769
epoch: 2 itera

epoch: 4 iteration: 2841:	 loss: 0.7031410932540894
epoch: 4 iteration: 2861:	 loss: 0.7203413844108582
epoch: 4 iteration: 2881:	 loss: 0.7229140996932983
epoch: 4 iteration: 2901:	 loss: 0.6826412677764893
epoch: 4 iteration: 2921:	 loss: 0.7854598760604858
epoch: 4 iteration: 2941:	 loss: 0.6853173971176147
epoch: 4 iteration: 2961:	 loss: 0.6784074306488037
epoch: 4 iteration: 2981:	 loss: 0.7418911457061768
epoch: 4 iteration: 3001:	 loss: 0.6989622116088867
epoch: 4 iteration: 3021:	 loss: 0.7201297879219055
epoch: 4 iteration: 3041:	 loss: 0.7005687952041626
epoch: 4 iteration: 3061:	 loss: 0.7240532636642456
epoch: 4 iteration: 3081:	 loss: 0.7324610948562622
epoch: 4 iteration: 3101:	 loss: 0.7100479602813721
epoch: 4 iteration: 3121:	 loss: 0.7154964208602905
epoch: 4 iteration: 3141:	 loss: 0.7039793133735657
epoch: 4 iteration: 3161:	 loss: 0.7241618037223816
epoch: 4 iteration: 3181:	 loss: 0.6765527129173279
epoch: 4 iteration: 3201:	 loss: 0.6946062445640564
epoch: 4 ite

epoch: 5 iteration: 2161:	 loss: 0.7509897351264954
epoch: 5 iteration: 2181:	 loss: 0.6879440546035767
epoch: 5 iteration: 2201:	 loss: 0.6581090688705444
epoch: 5 iteration: 2221:	 loss: 0.693567156791687
epoch: 5 iteration: 2241:	 loss: 0.7157626152038574
epoch: 5 iteration: 2261:	 loss: 0.6827032566070557
epoch: 5 iteration: 2281:	 loss: 0.7047484517097473
epoch: 5 iteration: 2301:	 loss: 0.7058311104774475
epoch: 5 iteration: 2321:	 loss: 0.6741711497306824
epoch: 5 iteration: 2341:	 loss: 0.7214606404304504
epoch: 5 iteration: 2361:	 loss: 0.7177162766456604
epoch: 5 iteration: 2381:	 loss: 0.6710331439971924
epoch: 5 iteration: 2401:	 loss: 0.7055031657218933
epoch: 5 iteration: 2421:	 loss: 0.6973045468330383
epoch: 5 iteration: 2441:	 loss: 0.692441999912262
epoch: 5 iteration: 2461:	 loss: 0.7025246620178223
epoch: 5 iteration: 2481:	 loss: 0.7348504662513733
epoch: 5 iteration: 2501:	 loss: 0.7703509330749512
epoch: 5 iteration: 2521:	 loss: 0.7352365851402283
epoch: 5 itera

epoch: 6 iteration: 1481:	 loss: 0.7097105383872986
epoch: 6 iteration: 1501:	 loss: 0.6916522979736328
epoch: 6 iteration: 1521:	 loss: 0.7173734307289124
epoch: 6 iteration: 1541:	 loss: 0.6826194524765015
epoch: 6 iteration: 1561:	 loss: 0.7290341854095459
epoch: 6 iteration: 1581:	 loss: 0.6926819086074829
epoch: 6 iteration: 1601:	 loss: 0.719490647315979
epoch: 6 iteration: 1621:	 loss: 0.6932433247566223
epoch: 6 iteration: 1641:	 loss: 0.7250713109970093
epoch: 6 iteration: 1661:	 loss: 0.7261515259742737
epoch: 6 iteration: 1681:	 loss: 0.6764593124389648
epoch: 6 iteration: 1701:	 loss: 0.6980302929878235
epoch: 6 iteration: 1721:	 loss: 0.6977190971374512
epoch: 6 iteration: 1741:	 loss: 0.6832029819488525
epoch: 6 iteration: 1761:	 loss: 0.7333337664604187
epoch: 6 iteration: 1781:	 loss: 0.7193721532821655
epoch: 6 iteration: 1801:	 loss: 0.7211495041847229
epoch: 6 iteration: 1821:	 loss: 0.7222370505332947
epoch: 6 iteration: 1841:	 loss: 0.6678541898727417
epoch: 6 iter

epoch: 7 iteration: 821:	 loss: 0.7295017242431641
epoch: 7 iteration: 841:	 loss: 0.694255530834198
epoch: 7 iteration: 861:	 loss: 0.7143446207046509
epoch: 7 iteration: 881:	 loss: 0.6654767990112305
epoch: 7 iteration: 901:	 loss: 0.7120242118835449
epoch: 7 iteration: 921:	 loss: 0.7256877422332764
epoch: 7 iteration: 941:	 loss: 0.7193502187728882
epoch: 7 iteration: 961:	 loss: 0.7003538012504578
epoch: 7 iteration: 981:	 loss: 0.6837663054466248
epoch: 7 iteration: 1001:	 loss: 0.7189288139343262
epoch: 7 iteration: 1021:	 loss: 0.7115572094917297
epoch: 7 iteration: 1041:	 loss: 0.7116002440452576
epoch: 7 iteration: 1061:	 loss: 0.7423796057701111
epoch: 7 iteration: 1081:	 loss: 0.7010359168052673
epoch: 7 iteration: 1101:	 loss: 0.7326087951660156
epoch: 7 iteration: 1121:	 loss: 0.700127899646759
epoch: 7 iteration: 1141:	 loss: 0.694929301738739
epoch: 7 iteration: 1161:	 loss: 0.7184862494468689
epoch: 7 iteration: 1181:	 loss: 0.6521796584129333
epoch: 7 iteration: 1201

epoch: 8 iteration: 141:	 loss: 0.6575555801391602
epoch: 8 iteration: 161:	 loss: 0.6804291009902954
epoch: 8 iteration: 181:	 loss: 0.7076748609542847
epoch: 8 iteration: 201:	 loss: 0.7211918830871582
epoch: 8 iteration: 221:	 loss: 0.7084629535675049
epoch: 8 iteration: 241:	 loss: 0.7000565528869629
epoch: 8 iteration: 261:	 loss: 0.6871145963668823
epoch: 8 iteration: 281:	 loss: 0.717145562171936
epoch: 8 iteration: 301:	 loss: 0.7070290446281433
epoch: 8 iteration: 321:	 loss: 0.7069810032844543
epoch: 8 iteration: 341:	 loss: 0.6872931122779846
epoch: 8 iteration: 361:	 loss: 0.7199509739875793
epoch: 8 iteration: 381:	 loss: 0.6972832083702087
epoch: 8 iteration: 401:	 loss: 0.7106162309646606
epoch: 8 iteration: 421:	 loss: 0.7097024917602539
epoch: 8 iteration: 441:	 loss: 0.7148528695106506
epoch: 8 iteration: 461:	 loss: 0.6947467923164368
epoch: 8 iteration: 481:	 loss: 0.7242609262466431
epoch: 8 iteration: 501:	 loss: 0.7113272547721863
epoch: 8 iteration: 521:	 loss: 

epoch: 8 iteration: 3321:	 loss: 0.6969917416572571
epoch: 8 iteration: 3341:	 loss: 0.7223015427589417
epoch: 8 iteration: 3361:	 loss: 0.6984889507293701
epoch: 8 iteration: 3381:	 loss: 0.6921043395996094
epoch: 8 iteration: 3401:	 loss: 0.6854057908058167
epoch: 8 iteration: 3421:	 loss: 0.6882438063621521
epoch: 8 iteration: 3441:	 loss: 0.6886479258537292
epoch: 8 iteration: 3461:	 loss: 0.6809309124946594
epoch: 8 iteration: 3481:	 loss: 0.7095460295677185
epoch: 8 iteration: 3501:	 loss: 0.715329647064209
epoch: 8 iteration: 3521:	 loss: 0.6791632175445557
epoch: 8 iteration: 3541:	 loss: 0.6725143790245056
epoch: 8 iteration: 3561:	 loss: 0.6793286204338074
epoch: 8 iteration: 3581:	 loss: 0.7085526585578918
epoch: 8 iteration: 3601:	 loss: 0.7034538388252258
epoch: 8 iteration: 3621:	 loss: 0.7010341882705688
epoch: 8 iteration: 3641:	 loss: 0.7140025496482849
epoch: 8 iteration: 3661:	 loss: 0.7316170334815979
epoch: 8 iteration: 3681:	 loss: 0.6933761835098267
epoch: 8 iter

epoch: 9 iteration: 2641:	 loss: 0.7035819888114929
epoch: 9 iteration: 2661:	 loss: 0.7204104065895081
epoch: 9 iteration: 2681:	 loss: 0.6697747707366943
epoch: 9 iteration: 2701:	 loss: 0.6872157454490662
epoch: 9 iteration: 2721:	 loss: 0.6927103996276855
epoch: 9 iteration: 2741:	 loss: 0.6963769197463989
epoch: 9 iteration: 2761:	 loss: 0.7158077955245972
epoch: 9 iteration: 2781:	 loss: 0.6685027480125427
epoch: 9 iteration: 2801:	 loss: 0.6915385723114014
epoch: 9 iteration: 2821:	 loss: 0.690886914730072
epoch: 9 iteration: 2841:	 loss: 0.7014902830123901
epoch: 9 iteration: 2861:	 loss: 0.6913484334945679
epoch: 9 iteration: 2881:	 loss: 0.6841951012611389
epoch: 9 iteration: 2901:	 loss: 0.668598473072052
epoch: 9 iteration: 2921:	 loss: 0.6796818971633911
epoch: 9 iteration: 2941:	 loss: 0.6749527454376221
epoch: 9 iteration: 2961:	 loss: 0.738035261631012
epoch: 9 iteration: 2981:	 loss: 0.6978758573532104
epoch: 9 iteration: 3001:	 loss: 0.6926296949386597
epoch: 9 iterat

epoch: 10 iteration: 1921:	 loss: 0.6765667200088501
epoch: 10 iteration: 1941:	 loss: 0.7186652421951294
epoch: 10 iteration: 1961:	 loss: 0.7123453617095947
epoch: 10 iteration: 1981:	 loss: 0.6934832334518433
epoch: 10 iteration: 2001:	 loss: 0.710435152053833
epoch: 10 iteration: 2021:	 loss: 0.6940099000930786
epoch: 10 iteration: 2041:	 loss: 0.7125563025474548
epoch: 10 iteration: 2061:	 loss: 0.7185832858085632
epoch: 10 iteration: 2081:	 loss: 0.711207389831543
epoch: 10 iteration: 2101:	 loss: 0.6900940537452698
epoch: 10 iteration: 2121:	 loss: 0.6513157486915588
epoch: 10 iteration: 2141:	 loss: 0.7061230540275574
epoch: 10 iteration: 2161:	 loss: 0.6783627271652222
epoch: 10 iteration: 2181:	 loss: 0.6870464086532593
epoch: 10 iteration: 2201:	 loss: 0.7124716639518738
epoch: 10 iteration: 2221:	 loss: 0.6896898746490479
epoch: 10 iteration: 2241:	 loss: 0.6909124851226807
epoch: 10 iteration: 2261:	 loss: 0.6787837743759155
epoch: 10 iteration: 2281:	 loss: 0.710066497325

epoch: 11 iteration: 1181:	 loss: 0.6682658195495605
epoch: 11 iteration: 1201:	 loss: 0.6899187564849854
epoch: 11 iteration: 1221:	 loss: 0.6888179183006287
epoch: 11 iteration: 1241:	 loss: 0.686010479927063
epoch: 11 iteration: 1261:	 loss: 0.6620437502861023
epoch: 11 iteration: 1281:	 loss: 0.6814656257629395
epoch: 11 iteration: 1301:	 loss: 0.6466128826141357
epoch: 11 iteration: 1321:	 loss: 0.6843525171279907
epoch: 11 iteration: 1341:	 loss: 0.6949517130851746
epoch: 11 iteration: 1361:	 loss: 0.7095741033554077
epoch: 11 iteration: 1381:	 loss: 0.6880443692207336
epoch: 11 iteration: 1401:	 loss: 0.6589187383651733
epoch: 11 iteration: 1421:	 loss: 0.6860829591751099
epoch: 11 iteration: 1441:	 loss: 0.6944498419761658
epoch: 11 iteration: 1461:	 loss: 0.6962052583694458
epoch: 11 iteration: 1481:	 loss: 0.7002667784690857
epoch: 11 iteration: 1501:	 loss: 0.6942528486251831
epoch: 11 iteration: 1521:	 loss: 0.6870553493499756
epoch: 11 iteration: 1541:	 loss: 0.68620616197

epoch: 12 iteration: 441:	 loss: 0.6970033049583435
epoch: 12 iteration: 461:	 loss: 0.6852786540985107
epoch: 12 iteration: 481:	 loss: 0.6983580589294434
epoch: 12 iteration: 501:	 loss: 0.6853769421577454
epoch: 12 iteration: 521:	 loss: 0.673386812210083
epoch: 12 iteration: 541:	 loss: 0.6817893981933594
epoch: 12 iteration: 561:	 loss: 0.6733483672142029
epoch: 12 iteration: 581:	 loss: 0.6866641640663147
epoch: 12 iteration: 601:	 loss: 0.6658166646957397
epoch: 12 iteration: 621:	 loss: 0.7126324772834778
epoch: 12 iteration: 641:	 loss: 0.6765681505203247
epoch: 12 iteration: 661:	 loss: 0.6698620319366455
epoch: 12 iteration: 681:	 loss: 0.680462658405304
epoch: 12 iteration: 701:	 loss: 0.7071147561073303
epoch: 12 iteration: 721:	 loss: 0.699354350566864
epoch: 12 iteration: 741:	 loss: 0.6798421144485474
epoch: 12 iteration: 761:	 loss: 0.6913890838623047
epoch: 12 iteration: 781:	 loss: 0.6751869916915894
epoch: 12 iteration: 801:	 loss: 0.705782413482666
epoch: 12 iterat

epoch: 12 iteration: 3561:	 loss: 0.6965742707252502
epoch: 12 iteration: 3581:	 loss: 0.7014415264129639
epoch: 12 iteration: 3601:	 loss: 0.6794609427452087
epoch: 12 iteration: 3621:	 loss: 0.698805570602417
epoch: 12 iteration: 3641:	 loss: 0.6934251189231873
epoch: 12 iteration: 3661:	 loss: 0.6529586911201477
epoch: 12 iteration: 3681:	 loss: 0.6981339454650879
epoch: 12 iteration: 3701:	 loss: 0.6933721899986267
epoch: 12 iteration: 3721:	 loss: 0.6816564798355103
epoch: 12 iteration: 3741:	 loss: 0.6851411461830139
epoch: 12 iteration: 3761:	 loss: 0.7058091759681702
epoch: 12 iteration: 3781:	 loss: 0.6640170216560364
epoch: 12 iteration: 3801:	 loss: 0.71771240234375
EarlyStopping counter: 1 out of 2
----------------epoch 12 end---------------------
epoch: 13 iteration: 1:	 loss: 0.676289439201355
epoch: 13 iteration: 21:	 loss: 0.7042034864425659
epoch: 13 iteration: 41:	 loss: 0.7057744264602661
epoch: 13 iteration: 61:	 loss: 0.676931619644165
epoch: 13 iteration: 81:	 los

epoch: 13 iteration: 2841:	 loss: 0.6955357789993286
epoch: 13 iteration: 2861:	 loss: 0.7028803825378418
epoch: 13 iteration: 2881:	 loss: 0.6921534538269043
epoch: 13 iteration: 2901:	 loss: 0.6858015656471252
epoch: 13 iteration: 2921:	 loss: 0.7132366895675659
epoch: 13 iteration: 2941:	 loss: 0.7064408659934998
epoch: 13 iteration: 2961:	 loss: 0.6984405517578125
epoch: 13 iteration: 2981:	 loss: 0.6738322973251343
epoch: 13 iteration: 3001:	 loss: 0.6929647922515869
epoch: 13 iteration: 3021:	 loss: 0.6862084865570068
epoch: 13 iteration: 3041:	 loss: 0.7089731693267822
epoch: 13 iteration: 3061:	 loss: 0.7124587893486023
epoch: 13 iteration: 3081:	 loss: 0.6916874051094055
epoch: 13 iteration: 3101:	 loss: 0.716401219367981
epoch: 13 iteration: 3121:	 loss: 0.707781195640564
epoch: 13 iteration: 3141:	 loss: 0.6658381223678589
epoch: 13 iteration: 3161:	 loss: 0.6590619683265686
epoch: 13 iteration: 3181:	 loss: 0.6808151006698608
epoch: 13 iteration: 3201:	 loss: 0.705850422382

epoch: 14 iteration: 2101:	 loss: 0.6743864417076111
epoch: 14 iteration: 2121:	 loss: 0.6945976614952087
epoch: 14 iteration: 2141:	 loss: 0.6965017914772034
epoch: 14 iteration: 2161:	 loss: 0.6986355781555176
epoch: 14 iteration: 2181:	 loss: 0.6892490386962891
epoch: 14 iteration: 2201:	 loss: 0.7217488288879395
epoch: 14 iteration: 2221:	 loss: 0.7006090879440308
epoch: 14 iteration: 2241:	 loss: 0.6734192371368408
epoch: 14 iteration: 2261:	 loss: 0.703910768032074
epoch: 14 iteration: 2281:	 loss: 0.6634453535079956
epoch: 14 iteration: 2301:	 loss: 0.6861295104026794
epoch: 14 iteration: 2321:	 loss: 0.6955384612083435
epoch: 14 iteration: 2341:	 loss: 0.6860319375991821
epoch: 14 iteration: 2361:	 loss: 0.6816280484199524
epoch: 14 iteration: 2381:	 loss: 0.6925039291381836
epoch: 14 iteration: 2401:	 loss: 0.66129469871521
epoch: 14 iteration: 2421:	 loss: 0.680393397808075
epoch: 14 iteration: 2441:	 loss: 0.6501152515411377
epoch: 14 iteration: 2461:	 loss: 0.68335449695587

epoch: 15 iteration: 1381:	 loss: 0.6769047975540161
epoch: 15 iteration: 1401:	 loss: 0.6778184771537781
epoch: 15 iteration: 1421:	 loss: 0.676692545413971
epoch: 15 iteration: 1441:	 loss: 0.6796646118164062
epoch: 15 iteration: 1461:	 loss: 0.654241144657135
epoch: 15 iteration: 1481:	 loss: 0.6880954504013062
epoch: 15 iteration: 1501:	 loss: 0.6962665915489197
epoch: 15 iteration: 1521:	 loss: 0.6692314147949219
epoch: 15 iteration: 1541:	 loss: 0.672501266002655
epoch: 15 iteration: 1561:	 loss: 0.6602296233177185
epoch: 15 iteration: 1581:	 loss: 0.711499035358429
epoch: 15 iteration: 1601:	 loss: 0.6662774085998535
epoch: 15 iteration: 1621:	 loss: 0.7020213603973389
epoch: 15 iteration: 1641:	 loss: 0.6775001883506775
epoch: 15 iteration: 1661:	 loss: 0.6915495991706848
epoch: 15 iteration: 1681:	 loss: 0.7263628244400024
epoch: 15 iteration: 1701:	 loss: 0.6601253747940063
epoch: 15 iteration: 1721:	 loss: 0.6844063997268677
epoch: 15 iteration: 1741:	 loss: 0.68171560764312

In [23]:
# Pickle the entire module object (architecture + weights), not just its state_dict.
torch.save(obj=model, f='./cnn_model/model_v02.pth')

In [24]:
# Persist the per-epoch learning-rate history (index = 0-based epoch number).
lr_history = pd.Series(lr)
lr_history.to_csv("./cnn_model/lr.txt")

In [None]:
# # check if validation loss has improved
#         if val_loss < best_val_loss:
#             best_val_loss = val_loss
#             counter = 0
#         else:
#             counter += 1
#             if counter >= patience:
#                 print("Early stopping!")
#                 break

In [None]:
# Training-loss curve: the sampled batch losses plus a 20-point moving average.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title('PyTorch_CNN_Loss')
ax.plot(loss_count, label='Loss')
ax.plot(pd.DataFrame(np.array(loss_count)).rolling(20, 1).mean(), label='MA Loss')
ax.set_xlabel('logged step (every 20th training batch)')
ax.set_ylabel('cross-entropy loss')
ax.legend()
# fig.savefig('./cnn_model/CNN_classification_loss.png')
plt.show()

## 对模型的修正，和论文方法对齐

In [6]:
# Scratch check of centred formatting: `_^20` pads the value to 20 chars with '_' on both sides.
name = 'hhu'
print("{:_^20}".format('zzzzzzzzzzzz'))

name = 'zzzzzz'
print("{:_^20}".format('name'))  # centres the literal text 'name', not the variable `name`

____zzzzzzzzzzzz____
________name________
