In [6]:
#train3 from 2: use adam, remove weight_decay; one BW channel (not 3)
import torch,os,glob
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchvision.transforms import v2
from torchvision import tv_tensors
import pandas as pd
import numpy as np

from nets import ResNet18_3lbC_nlbM
from tqdm import tqdm
from configparser import ConfigParser
from torch.utils.data import  DataLoader
from LIDC_M_data import LIDC_Dataset
import pickle

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
best_epoch = 0
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

prep_tr = [
    v2.Lambda(lambda x: tv_tensors.Image(torch.clamp(x,-1000.,400.)) if isinstance(x, tv_tensors.Image) else x),
    v2.Lambda(lambda x: tv_tensors.Image((x+1000)/1400) if isinstance(x, tv_tensors.Image) else x),
    v2.CenterCrop((384,384)),
]
aug_tr = [
    v2.RandomAffine(degrees=10),
    v2.RandomHorizontalFlip(),
]
trans_train = v2.Compose( prep_tr + aug_tr )
trans_test = v2.Compose( prep_tr  )

In [8]:

parser = ConfigParser()
parser.read('.settings')
root_dir = parser.get('dataset','root_dir') #/workspaces/data/lidc-idri/slices
meta_dir = parser.get('dataset','meta_dir') #/workspaces/data/lidc-idri/splits
result_dir = os.path.join(parser.get('dataset','result_dir'),'stage2/basel0_3lbC_2lbM')
if not os.path.isdir(result_dir):
        os.mkdir(result_dir)

train_data = LIDC_Dataset(root_dir,metapath=os.path.join(meta_dir,'trainBB_malB.csv'),transform=trans_train, loadBB=True)
test_data = LIDC_Dataset(root_dir,metapath=os.path.join(meta_dir,'testBB_malB.csv'),transform=trans_test)
total_train_data = len(train_data)
total_test_data = len(test_data)
print('total_train_data:',total_train_data, 'total_test_data:',total_test_data)

batch_size = int(parser['dataset']['batch_size'])
trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8)
testloader = DataLoader(test_data, batch_size=batch_size, num_workers=8)

total_train_data: 5495 total_test_data: 2354


In [9]:
net = ResNet18_3lbC_nlbM(pretrained=True,attr="MGA2")
net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
net.fc = nn.Linear(net.fc.in_features, 2)
net = net.to(device)

In [10]:
from torchinfo import summary
summary(net, input_size=(batch_size,1, 384, 384))

Layer (type:depth-idx)                             Output Shape              Param #
ResNet18_3lbC_nlbM                                 [32, 2]                   --
├─Conv2d: 1-1                                      [32, 64, 192, 192]        3,136
├─BatchNorm2d: 1-2                                 [32, 64, 192, 192]        128
├─ReLU: 1-3                                        [32, 64, 192, 192]        --
├─MaxPool2d: 1-4                                   [32, 64, 96, 96]          --
├─Sequential: 1-5                                  [32, 64, 96, 96]          --
│    └─BasicBlock: 2-1                             [32, 64, 96, 96]          --
│    │    └─Conv2d: 3-1                            [32, 64, 96, 96]          36,864
│    │    └─BatchNorm2d: 3-2                       [32, 64, 96, 96]          128
│    │    └─ReLU: 3-3                              [32, 64, 96, 96]          --
│    │    └─Conv2d: 3-4                            [32, 64, 96, 96]          36,864
│    │    └─BatchNorm2

In [11]:
lr = 1e-4
criterion = nn.CrossEntropyLoss()
mse = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=lr)

training_info=[["epoch","acc","loss"]]
testing_info=training_info.copy()

In [12]:
def train(epoch):
    net.train()
    train_loss = np.zeros(3)
    correct = 0
    total = 0
    pbar = tqdm(trainloader)
    for batch_idx, (inputs, targets, masks) in enumerate(pbar):
        inputs, targets, masks = inputs.to(device), targets.to(device), masks.to(device)

        optimizer.zero_grad()
        # outputs = net(inputs)
        outputs, attn_map = net(inputs)
        cls_loss = criterion(outputs, targets)
        masks = F.adaptive_avg_pool2d(masks, attn_map.shape[-2:])
        att_loss = 100.* mse(attn_map , masks)
        
        loss =  cls_loss + att_loss
        
        # loss = criterion(outputs, targets)
        loss.backward()
        
        optimizer.step()

        train_loss += np.array([loss.item(), cls_loss.item(), att_loss.item()])
        _, predicted = outputs.max(1)
        
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        pbar.set_description(f"Epoch: {epoch} Acc: {(100.*correct/total):.2f}")

    train_acc = 100.*correct/total
    train_loss = train_loss/(batch_idx+1)
    print(f"Tot Loss: {train_loss[0]:.4f} CL: {train_loss[1]:.5f} AT: {train_loss[2]:.5f}; Train Acc: {train_acc:.2f}%")
    training_info.append([epoch,train_acc,train_loss])
    # trainning_accuracy.append(train_acc)
    # trainning_loss.append( train_loss )

def test(epoch, islast = False):
    global best_acc, best_epoch
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    pbar = tqdm(testloader)
    with torch.no_grad():
        for batch_idx, (inputs, targets ) in enumerate(pbar):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        test_acc = 100.*correct/total
        test_loss = test_loss/(batch_idx+1)
        print(f"Test Loss: {test_loss}, Test Acc: {test_acc:.2f}%")
        testing_info.append([epoch,test_acc,test_loss])
        # testing_accuracy.append(test_acc)
        # testing_loss.append(test_loss)
    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc or islast:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        savestr = 'best' if acc > best_acc else 'last'
        torch.save(state, os.path.join(result_dir,f'basel0-b{batch_size}-epoch{epoch}-{savestr}.pth'))
        best_acc = acc
        best_epoch = epoch

In [13]:
aa = np.zeros(3)
aa = [1,2,3]

In [14]:
print(aa)

[1, 2, 3]


In [15]:
if start_epoch>0:
    checkpoint = torch.load(glob.glob(os.path.join(result_dir,f'basel0-b{batch_size}-epoch{start_epoch-1}-*.pth'))[0] )
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']

for epoch in range(start_epoch, start_epoch+50):
    train(epoch)
    test(epoch, islast = epoch==start_epoch+49)

Epoch: 0 Acc: 57.71: 100%|██████████| 172/172 [00:43<00:00,  3.94it/s]


Tot Loss: 22.6644 CL: 0.68351 AT: 21.98084; Train Acc: 57.71%


100%|██████████| 74/74 [00:06<00:00, 12.10it/s]


Test Loss: 0.670093152974103, Test Acc: 58.54%
Saving..


Epoch: 1 Acc: 61.73: 100%|██████████| 172/172 [00:42<00:00,  4.02it/s]


Tot Loss: 4.6948 CL: 0.65887 AT: 4.03596; Train Acc: 61.73%


100%|██████████| 74/74 [00:06<00:00, 12.25it/s]


Test Loss: 0.6307749921405638, Test Acc: 64.40%
Saving..


Epoch: 2 Acc: 63.49: 100%|██████████| 172/172 [00:44<00:00,  3.89it/s]


Tot Loss: 1.5586 CL: 0.63578 AT: 0.92278; Train Acc: 63.49%


100%|██████████| 74/74 [00:05<00:00, 12.76it/s]


Test Loss: 0.6371134004077396, Test Acc: 63.38%


Epoch: 3 Acc: 66.46: 100%|██████████| 172/172 [00:41<00:00,  4.14it/s]


Tot Loss: 1.0769 CL: 0.61070 AT: 0.46621; Train Acc: 66.46%


100%|██████████| 74/74 [00:06<00:00, 11.79it/s]


Test Loss: 0.6051629378988936, Test Acc: 66.82%
Saving..


Epoch: 4 Acc: 69.23: 100%|██████████| 172/172 [00:42<00:00,  4.00it/s]


Tot Loss: 0.8868 CL: 0.58382 AT: 0.30299; Train Acc: 69.23%


100%|██████████| 74/74 [00:05<00:00, 12.36it/s]


Test Loss: 0.587622046067908, Test Acc: 69.24%
Saving..


Epoch: 5 Acc: 70.99: 100%|██████████| 172/172 [00:44<00:00,  3.82it/s]


Tot Loss: 0.7795 CL: 0.55326 AT: 0.22627; Train Acc: 70.99%


100%|██████████| 74/74 [00:09<00:00,  8.15it/s]


Test Loss: 0.5627086033692231, Test Acc: 71.20%
Saving..


Epoch: 6 Acc: 73.74: 100%|██████████| 172/172 [00:42<00:00,  4.08it/s]


Tot Loss: 0.7076 CL: 0.52913 AT: 0.17850; Train Acc: 73.74%


100%|██████████| 74/74 [00:08<00:00,  8.92it/s]


Test Loss: 0.5424236292774612, Test Acc: 72.85%
Saving..


Epoch: 7 Acc: 76.96: 100%|██████████| 172/172 [00:41<00:00,  4.13it/s]


Tot Loss: 0.6365 CL: 0.48627 AT: 0.15026; Train Acc: 76.96%


100%|██████████| 74/74 [00:12<00:00,  5.86it/s]


Test Loss: 0.6079249176624659, Test Acc: 71.33%


Epoch: 8 Acc: 77.91: 100%|██████████| 172/172 [00:40<00:00,  4.26it/s]


Tot Loss: 0.6006 CL: 0.46730 AT: 0.13330; Train Acc: 77.91%


100%|██████████| 74/74 [00:07<00:00, 10.02it/s]


Test Loss: 0.5466922509509164, Test Acc: 73.45%
Saving..


Epoch: 9 Acc: 79.95: 100%|██████████| 172/172 [00:43<00:00,  3.96it/s]


Tot Loss: 0.5542 CL: 0.43252 AT: 0.12166; Train Acc: 79.95%


100%|██████████| 74/74 [00:05<00:00, 12.53it/s]


Test Loss: 0.5279919738705093, Test Acc: 74.60%
Saving..


Epoch: 10 Acc: 82.00: 100%|██████████| 172/172 [00:49<00:00,  3.49it/s]


Tot Loss: 0.5141 CL: 0.40119 AT: 0.11294; Train Acc: 82.00%


100%|██████████| 74/74 [00:10<00:00,  6.80it/s]


Test Loss: 0.51841977721936, Test Acc: 76.89%
Saving..


Epoch: 11 Acc: 84.06: 100%|██████████| 172/172 [00:41<00:00,  4.14it/s]


Tot Loss: 0.4700 CL: 0.36337 AT: 0.10666; Train Acc: 84.06%


100%|██████████| 74/74 [00:08<00:00,  8.54it/s]


Test Loss: 0.5407171696424484, Test Acc: 76.72%


Epoch: 12 Acc: 84.71: 100%|██████████| 172/172 [00:41<00:00,  4.16it/s]


Tot Loss: 0.4492 CL: 0.34751 AT: 0.10172; Train Acc: 84.71%


100%|██████████| 74/74 [00:10<00:00,  6.97it/s]


Test Loss: 0.46382650612173854, Test Acc: 80.12%
Saving..


Epoch: 13 Acc: 86.26: 100%|██████████| 172/172 [00:42<00:00,  4.05it/s]


Tot Loss: 0.4277 CL: 0.33057 AT: 0.09713; Train Acc: 86.26%


100%|██████████| 74/74 [00:27<00:00,  2.72it/s]


Test Loss: 0.4431141109482662, Test Acc: 81.52%
Saving..


Epoch: 14 Acc: 87.22: 100%|██████████| 172/172 [00:44<00:00,  3.88it/s]


Tot Loss: 0.4127 CL: 0.31877 AT: 0.09396; Train Acc: 87.22%


100%|██████████| 74/74 [00:05<00:00, 12.58it/s]


Test Loss: 0.42312234076293737, Test Acc: 82.75%
Saving..


Epoch: 15 Acc: 88.10: 100%|██████████| 172/172 [00:43<00:00,  3.93it/s]


Tot Loss: 0.3819 CL: 0.29057 AT: 0.09129; Train Acc: 88.10%


100%|██████████| 74/74 [00:06<00:00, 11.91it/s]


Test Loss: 0.4590800485095462, Test Acc: 81.56%


Epoch: 16 Acc: 89.03: 100%|██████████| 172/172 [00:44<00:00,  3.86it/s]


Tot Loss: 0.3502 CL: 0.26170 AT: 0.08853; Train Acc: 89.03%


100%|██████████| 74/74 [00:12<00:00,  5.82it/s]


Test Loss: 0.5116811857835667, Test Acc: 80.20%


Epoch: 17 Acc: 89.74: 100%|██████████| 172/172 [00:41<00:00,  4.19it/s]


Tot Loss: 0.3418 CL: 0.25530 AT: 0.08652; Train Acc: 89.74%


100%|██████████| 74/74 [00:06<00:00, 10.81it/s]


Test Loss: 0.46954007265535563, Test Acc: 82.63%


Epoch: 18 Acc: 90.65: 100%|██████████| 172/172 [00:46<00:00,  3.69it/s]


Tot Loss: 0.3187 CL: 0.23433 AT: 0.08435; Train Acc: 90.65%


100%|██████████| 74/74 [00:08<00:00,  9.21it/s]


Test Loss: 0.39848584369630424, Test Acc: 84.58%
Saving..


Epoch: 19 Acc: 91.26: 100%|██████████| 172/172 [00:41<00:00,  4.10it/s]


Tot Loss: 0.2997 CL: 0.21737 AT: 0.08235; Train Acc: 91.26%


100%|██████████| 74/74 [00:09<00:00,  8.05it/s]


Test Loss: 0.434001741175716, Test Acc: 84.11%


Epoch: 20 Acc: 91.88: 100%|██████████| 172/172 [00:46<00:00,  3.72it/s]


Tot Loss: 0.2893 CL: 0.20879 AT: 0.08050; Train Acc: 91.88%


100%|██████████| 74/74 [00:10<00:00,  6.98it/s]


Test Loss: 0.4593456877848586, Test Acc: 82.97%


Epoch: 21 Acc: 91.72: 100%|██████████| 172/172 [00:51<00:00,  3.34it/s]


Tot Loss: 0.2900 CL: 0.21105 AT: 0.07890; Train Acc: 91.72%


100%|██████████| 74/74 [00:07<00:00, 10.45it/s]


Test Loss: 0.42636362694807955, Test Acc: 83.77%


Epoch: 22 Acc: 92.50: 100%|██████████| 172/172 [00:53<00:00,  3.22it/s]


Tot Loss: 0.2703 CL: 0.19305 AT: 0.07726; Train Acc: 92.50%


100%|██████████| 74/74 [00:06<00:00, 11.99it/s]


Test Loss: 0.4203342869877815, Test Acc: 83.98%


Epoch: 23 Acc: 92.61: 100%|██████████| 172/172 [00:41<00:00,  4.11it/s]


Tot Loss: 0.2618 CL: 0.18625 AT: 0.07558; Train Acc: 92.61%


100%|██████████| 74/74 [00:06<00:00, 11.35it/s]


Test Loss: 0.41095291365038705, Test Acc: 85.30%
Saving..


Epoch: 24 Acc: 93.56: 100%|██████████| 172/172 [00:48<00:00,  3.52it/s]


Tot Loss: 0.2423 CL: 0.16782 AT: 0.07451; Train Acc: 93.56%


100%|██████████| 74/74 [00:12<00:00,  6.08it/s]


Test Loss: 0.499245225376374, Test Acc: 81.65%


Epoch: 25 Acc: 92.68: 100%|██████████| 172/172 [00:44<00:00,  3.89it/s]


Tot Loss: 0.2630 CL: 0.18958 AT: 0.07341; Train Acc: 92.68%


100%|██████████| 74/74 [00:09<00:00,  7.57it/s]


Test Loss: 0.482177708700702, Test Acc: 83.73%


Epoch: 26 Acc: 94.14: 100%|██████████| 172/172 [00:41<00:00,  4.13it/s]


Tot Loss: 0.2299 CL: 0.15771 AT: 0.07223; Train Acc: 94.14%


100%|██████████| 74/74 [00:06<00:00, 11.03it/s]


Test Loss: 0.43326170923742086, Test Acc: 85.81%
Saving..


Epoch: 27 Acc: 94.21: 100%|██████████| 172/172 [00:50<00:00,  3.43it/s]


Tot Loss: 0.2276 CL: 0.15647 AT: 0.07115; Train Acc: 94.21%


100%|██████████| 74/74 [00:06<00:00, 11.57it/s]


Test Loss: 0.4218609616965861, Test Acc: 86.36%
Saving..


Epoch: 28 Acc: 94.21: 100%|██████████| 172/172 [00:44<00:00,  3.83it/s]


Tot Loss: 0.2115 CL: 0.14158 AT: 0.06991; Train Acc: 94.21%


100%|██████████| 74/74 [00:09<00:00,  8.17it/s]


Test Loss: 0.4502886432550243, Test Acc: 85.30%


Epoch: 29 Acc: 94.74: 100%|██████████| 172/172 [00:52<00:00,  3.26it/s]


Tot Loss: 0.2097 CL: 0.14072 AT: 0.06896; Train Acc: 94.74%


100%|██████████| 74/74 [00:06<00:00, 12.07it/s]


Test Loss: 0.4788712030528365, Test Acc: 85.60%


Epoch: 30 Acc: 95.40: 100%|██████████| 172/172 [00:44<00:00,  3.86it/s]


Tot Loss: 0.2036 CL: 0.13552 AT: 0.06809; Train Acc: 95.40%


100%|██████████| 74/74 [00:08<00:00,  8.63it/s]


Test Loss: 0.5106841858375717, Test Acc: 85.26%


Epoch: 31 Acc: 95.32: 100%|██████████| 172/172 [00:53<00:00,  3.19it/s]


Tot Loss: 0.1898 CL: 0.12263 AT: 0.06716; Train Acc: 95.32%


100%|██████████| 74/74 [00:06<00:00, 11.52it/s]


Test Loss: 0.42950098186328606, Test Acc: 86.62%
Saving..


Epoch: 32 Acc: 94.94: 100%|██████████| 172/172 [00:56<00:00,  3.02it/s]


Tot Loss: 0.1974 CL: 0.13042 AT: 0.06698; Train Acc: 94.94%


100%|██████████| 74/74 [00:09<00:00,  7.80it/s]


Test Loss: 0.4626842131083076, Test Acc: 85.77%


Epoch: 33 Acc: 95.43: 100%|██████████| 172/172 [00:45<00:00,  3.74it/s]


Tot Loss: 0.1936 CL: 0.12715 AT: 0.06642; Train Acc: 95.43%


100%|██████████| 74/74 [00:08<00:00,  8.39it/s]


Test Loss: 0.48089997097849846, Test Acc: 84.75%


Epoch: 34 Acc: 95.05: 100%|██████████| 172/172 [00:48<00:00,  3.55it/s]


Tot Loss: 0.1964 CL: 0.13077 AT: 0.06564; Train Acc: 95.05%


100%|██████████| 74/74 [00:07<00:00, 10.52it/s]


Test Loss: 0.44801409240510015, Test Acc: 86.28%


Epoch: 35 Acc: 96.05: 100%|██████████| 172/172 [00:43<00:00,  3.92it/s]


Tot Loss: 0.1800 CL: 0.11529 AT: 0.06470; Train Acc: 96.05%


100%|██████████| 74/74 [00:07<00:00, 10.53it/s]


Test Loss: 0.4857729294815579, Test Acc: 86.96%
Saving..


Epoch: 36 Acc: 95.94: 100%|██████████| 172/172 [00:48<00:00,  3.56it/s]


Tot Loss: 0.1769 CL: 0.11313 AT: 0.06373; Train Acc: 95.94%


100%|██████████| 74/74 [00:08<00:00,  8.29it/s]


Test Loss: 0.5364297131026113, Test Acc: 84.28%


Epoch: 37 Acc: 96.27: 100%|██████████| 172/172 [00:44<00:00,  3.89it/s]


Tot Loss: 0.1696 CL: 0.10630 AT: 0.06328; Train Acc: 96.27%


100%|██████████| 74/74 [00:06<00:00, 11.55it/s]


Test Loss: 0.4857939908633361, Test Acc: 85.22%


Epoch: 38 Acc: 96.32: 100%|██████████| 172/172 [00:48<00:00,  3.53it/s]


Tot Loss: 0.1644 CL: 0.10125 AT: 0.06311; Train Acc: 96.32%


100%|██████████| 74/74 [00:07<00:00,  9.84it/s]


Test Loss: 0.5066844944313571, Test Acc: 86.32%


Epoch: 39 Acc: 95.38: 100%|██████████| 172/172 [00:41<00:00,  4.10it/s]


Tot Loss: 0.1870 CL: 0.12454 AT: 0.06247; Train Acc: 95.38%


100%|██████████| 74/74 [00:05<00:00, 12.56it/s]


Test Loss: 0.4671507259780491, Test Acc: 85.39%


Epoch: 40 Acc: 96.09: 100%|██████████| 172/172 [00:46<00:00,  3.73it/s]


Tot Loss: 0.1635 CL: 0.10135 AT: 0.06220; Train Acc: 96.09%


100%|██████████| 74/74 [00:06<00:00, 10.59it/s]


Test Loss: 0.4759143031378453, Test Acc: 85.47%


Epoch: 41 Acc: 96.82: 100%|██████████| 172/172 [00:41<00:00,  4.10it/s]


Tot Loss: 0.1457 CL: 0.08440 AT: 0.06131; Train Acc: 96.82%


100%|██████████| 74/74 [00:11<00:00,  6.23it/s]


Test Loss: 0.5422699629656367, Test Acc: 84.15%


Epoch: 42 Acc: 96.23: 100%|██████████| 172/172 [00:42<00:00,  4.05it/s]


Tot Loss: 0.1618 CL: 0.10103 AT: 0.06076; Train Acc: 96.23%


100%|██████████| 74/74 [00:07<00:00,  9.74it/s]


Test Loss: 0.5125156018782306, Test Acc: 85.26%


Epoch: 43 Acc: 96.60: 100%|██████████| 172/172 [00:54<00:00,  3.14it/s]


Tot Loss: 0.1546 CL: 0.09428 AT: 0.06034; Train Acc: 96.60%


100%|██████████| 74/74 [00:06<00:00, 10.73it/s]


Test Loss: 0.4734225218008096, Test Acc: 86.83%


Epoch: 44 Acc: 96.54: 100%|██████████| 172/172 [00:42<00:00,  4.07it/s]


Tot Loss: 0.1516 CL: 0.09186 AT: 0.05972; Train Acc: 96.54%


100%|██████████| 74/74 [00:09<00:00,  8.13it/s]


Test Loss: 0.46447033576063207, Test Acc: 86.36%


Epoch: 45 Acc: 96.67: 100%|██████████| 172/172 [00:43<00:00,  3.95it/s]


Tot Loss: 0.1471 CL: 0.08755 AT: 0.05951; Train Acc: 96.67%


100%|██████████| 74/74 [00:05<00:00, 12.41it/s]


Test Loss: 0.4683013921351852, Test Acc: 86.70%


Epoch: 46 Acc: 97.02: 100%|██████████| 172/172 [00:42<00:00,  4.05it/s]


Tot Loss: 0.1408 CL: 0.08180 AT: 0.05897; Train Acc: 97.02%


100%|██████████| 74/74 [00:08<00:00,  8.48it/s]


Test Loss: 0.5002701163191248, Test Acc: 85.56%


Epoch: 47 Acc: 96.67: 100%|██████████| 172/172 [00:42<00:00,  4.00it/s]


Tot Loss: 0.1469 CL: 0.08810 AT: 0.05883; Train Acc: 96.67%


100%|██████████| 74/74 [00:14<00:00,  5.15it/s]


Test Loss: 0.5744739015762871, Test Acc: 84.54%


Epoch: 48 Acc: 97.02: 100%|██████████| 172/172 [00:45<00:00,  3.74it/s]


Tot Loss: 0.1405 CL: 0.08253 AT: 0.05795; Train Acc: 97.02%


100%|██████████| 74/74 [00:07<00:00, 10.01it/s]


Test Loss: 0.5295427970185473, Test Acc: 86.15%


Epoch: 49 Acc: 96.94: 100%|██████████| 172/172 [00:46<00:00,  3.72it/s]


Tot Loss: 0.1384 CL: 0.08076 AT: 0.05761; Train Acc: 96.94%


100%|██████████| 74/74 [00:09<00:00,  7.57it/s]


Test Loss: 0.5191629845649004, Test Acc: 85.98%
Saving..


In [16]:
traindf = pd.DataFrame(training_info[1:],columns=training_info[0])
testdf = pd.DataFrame(testing_info[1:],columns=testing_info[0])
with open(os.path.join(result_dir,f'basel0-b{batch_size}-info.pkl'), 'wb') as file:
    pickle.dump({"train":traindf,"test":testdf}, file)

 
 #   scheduler.step()

In [17]:
import re
levels = re.findall(r'MGA(\d+)','MGA22')
mga_level = int(levels[0]) if levels else 3
# 'MGA3'.startswith('MGA')
mga_level

22

In [18]:
testdf

Unnamed: 0,epoch,acc,loss
0,0,58.538658,0.670093
1,1,64.40102,0.630775
2,2,63.381478,0.637113
3,3,66.82243,0.605163
4,4,69.24384,0.587622
5,5,71.197961,0.562709
6,6,72.854715,0.542424
7,7,71.325404,0.607925
8,8,73.449448,0.546692
9,9,74.596432,0.527992


In [19]:
# with open('two_dfs.pkl', 'rb') as file:
#     loaded = pickle.load(file)

# loaded