In [1]:
#train3 from 2: use adam, remove weight_decay; one BW channel (not 3)
import torch,os,glob
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchvision.transforms import v2
from torchvision import tv_tensors
import pandas as pd

from nets import ResNet18_3lbCBAM
from tqdm import tqdm
from configparser import ConfigParser
from torch.utils.data import  DataLoader
from LIDC_M_data import LIDC_Dataset
import pickle

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
best_epoch = 0
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

prep_tr = [
    v2.Lambda(lambda x: tv_tensors.Image(torch.clamp(x,-1000.,400.)) if isinstance(x, tv_tensors.Image) else x),
    v2.Lambda(lambda x: tv_tensors.Image((x+1000)/1400) if isinstance(x, tv_tensors.Image) else x),
    v2.CenterCrop((384,384)),
]
aug_tr = [
    v2.RandomAffine(degrees=10),
    v2.RandomHorizontalFlip(),
]
trans_train = v2.Compose( prep_tr + aug_tr )
trans_test = v2.Compose( prep_tr  )

In [3]:

parser = ConfigParser()
parser.read('.settings')
root_dir = parser.get('dataset','root_dir') #/workspaces/data/lidc-idri/slices
meta_dir = parser.get('dataset','meta_dir') #/workspaces/data/lidc-idri/splits
result_dir = os.path.join(parser.get('dataset','result_dir'),'stage2/basel0_3lbMGA')
if not os.path.isdir(result_dir):
        os.mkdir(result_dir)

train_data = LIDC_Dataset(root_dir,metapath=os.path.join(meta_dir,'trainBB_malB.csv'),transform=trans_train, loadBB=True)
test_data = LIDC_Dataset(root_dir,metapath=os.path.join(meta_dir,'testBB_malB.csv'),transform=trans_test)
total_train_data = len(train_data)
total_test_data = len(test_data)
print('total_train_data:',total_train_data, 'total_test_data:',total_test_data)

batch_size = int(parser['dataset']['batch_size'])
trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8)
testloader = DataLoader(test_data, batch_size=batch_size, num_workers=8)

total_train_data: 5495 total_test_data: 2354


In [4]:
net = ResNet18_3lbCBAM(pretrained=True,attr="MGA")
net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
net.fc = nn.Linear(net.fc.in_features, 2)
net = net.to(device)

In [5]:
from torchinfo import summary
summary(net, input_size=(batch_size,1, 384, 384))

Layer (type:depth-idx)                             Output Shape              Param #
ResNet18_3lbCBAM                                   [32, 2]                   --
├─Conv2d: 1-1                                      [32, 64, 192, 192]        3,136
├─BatchNorm2d: 1-2                                 [32, 64, 192, 192]        128
├─ReLU: 1-3                                        [32, 64, 192, 192]        --
├─MaxPool2d: 1-4                                   [32, 64, 96, 96]          --
├─Sequential: 1-5                                  [32, 64, 96, 96]          --
│    └─BasicBlock: 2-1                             [32, 64, 96, 96]          --
│    │    └─Conv2d: 3-1                            [32, 64, 96, 96]          36,864
│    │    └─BatchNorm2d: 3-2                       [32, 64, 96, 96]          128
│    │    └─ReLU: 3-3                              [32, 64, 96, 96]          --
│    │    └─Conv2d: 3-4                            [32, 64, 96, 96]          36,864
│    │    └─BatchNorm2

In [6]:
lr = 1e-4
criterion = nn.CrossEntropyLoss()
mse = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=lr)

training_info=[["epoch","acc","loss"]]
testing_info=training_info.copy()

In [7]:
def train(epoch):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    pbar = tqdm(trainloader)
    for batch_idx, (inputs, targets, masks) in enumerate(pbar):
        inputs, targets, masks = inputs.to(device), targets.to(device), masks.to(device)

        optimizer.zero_grad()
        # outputs = net(inputs)
        outputs, attn_map = net(inputs)
        cls_loss = criterion(outputs, targets)
        masks = F.adaptive_avg_pool2d(masks, attn_map.shape[-2:])
        mse_loss = mse(attn_map , masks)
        loss =  cls_loss + 0.1 * mse_loss
        
        # loss = criterion(outputs, targets)
        loss.backward()
        
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        pbar.set_description(f"Epoch: {epoch} Acc: {(100.*correct/total):.2f}")

    train_acc = 100.*correct/total
    train_loss = train_loss/(batch_idx+1)
    print(f"Train Loss: {train_loss}, Train Acc: {train_acc:.2f}%")
    training_info.append([epoch,train_acc,train_loss])
    # trainning_accuracy.append(train_acc)
    # trainning_loss.append( train_loss )

def test(epoch, islast = False):
    global best_acc, best_epoch
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    pbar = tqdm(testloader)
    with torch.no_grad():
        for batch_idx, (inputs, targets ) in enumerate(pbar):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        test_acc = 100.*correct/total
        test_loss = test_loss/(batch_idx+1)
        print(f"Test Loss: {test_loss}, Test Acc: {test_acc:.2f}%")
        testing_info.append([epoch,test_acc,test_loss])
        # testing_accuracy.append(test_acc)
        # testing_loss.append(test_loss)
    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc or islast:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        savestr = 'best' if acc > best_acc else 'last'
        torch.save(state, os.path.join(result_dir,f'basel0_3lbMGA-b{batch_size}-epoch{epoch}-{savestr}.pth'))
        best_acc = acc
        best_epoch = epoch

In [8]:
if start_epoch>0:
    checkpoint = torch.load(glob.glob(os.path.join(result_dir,f'basel0_3lbMGA-b{batch_size}-epoch{start_epoch-1}-*.pth'))[0] )
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']

for epoch in range(start_epoch, start_epoch+50):
    train(epoch)
    test(epoch, islast = epoch==start_epoch+49)

Epoch: 0 Acc: 59.25: 100%|██████████| 172/172 [00:41<00:00,  4.16it/s]


Train Loss: 0.6962568021790926, Train Acc: 59.25%


100%|██████████| 74/74 [00:05<00:00, 13.05it/s]


Test Loss: 0.6415471958147513, Test Acc: 61.85%
Saving..


Epoch: 1 Acc: 64.93: 100%|██████████| 172/172 [00:44<00:00,  3.88it/s]


Train Loss: 0.6459493472479111, Train Acc: 64.93%


100%|██████████| 74/74 [00:05<00:00, 12.51it/s]


Test Loss: 0.6243702869963002, Test Acc: 64.53%
Saving..


Epoch: 2 Acc: 68.04: 100%|██████████| 172/172 [00:41<00:00,  4.16it/s]


Train Loss: 0.6054607068383416, Train Acc: 68.04%


100%|██████████| 74/74 [00:05<00:00, 12.97it/s]


Test Loss: 0.5887590316501824, Test Acc: 69.20%
Saving..


Epoch: 3 Acc: 72.08: 100%|██████████| 172/172 [00:44<00:00,  3.88it/s]


Train Loss: 0.5486847153583239, Train Acc: 72.08%


100%|██████████| 74/74 [00:05<00:00, 12.86it/s]


Test Loss: 0.6466946851562809, Test Acc: 65.93%


Epoch: 4 Acc: 75.41: 100%|██████████| 172/172 [00:41<00:00,  4.17it/s]


Train Loss: 0.5075082524224769, Train Acc: 75.41%


100%|██████████| 74/74 [00:05<00:00, 13.03it/s]


Test Loss: 0.5359337990348404, Test Acc: 75.28%
Saving..


Epoch: 5 Acc: 77.89: 100%|██████████| 172/172 [00:44<00:00,  3.89it/s]


Train Loss: 0.46555518826773, Train Acc: 77.89%


100%|██████████| 74/74 [00:05<00:00, 12.96it/s]


Test Loss: 0.6553705540057775, Test Acc: 67.63%


Epoch: 6 Acc: 80.60: 100%|██████████| 172/172 [00:41<00:00,  4.18it/s]


Train Loss: 0.42855265095483425, Train Acc: 80.60%


100%|██████████| 74/74 [00:05<00:00, 13.09it/s]


Test Loss: 0.5848832762724644, Test Acc: 72.56%


Epoch: 7 Acc: 82.22: 100%|██████████| 172/172 [00:44<00:00,  3.86it/s]


Train Loss: 0.39675416641457134, Train Acc: 82.22%


100%|██████████| 74/74 [00:05<00:00, 12.91it/s]


Test Loss: 0.4741110193568307, Test Acc: 78.29%
Saving..


Epoch: 8 Acc: 85.00: 100%|██████████| 172/172 [00:41<00:00,  4.17it/s]


Train Loss: 0.3556605016942634, Train Acc: 85.00%


100%|██████████| 74/74 [00:05<00:00, 12.97it/s]


Test Loss: 0.48581234144197927, Test Acc: 77.87%


Epoch: 9 Acc: 85.68: 100%|██████████| 172/172 [00:41<00:00,  4.15it/s]


Train Loss: 0.3255203660404266, Train Acc: 85.68%


100%|██████████| 74/74 [00:08<00:00,  8.32it/s]


Test Loss: 0.4758149480094781, Test Acc: 78.89%
Saving..


Epoch: 10 Acc: 87.44: 100%|██████████| 172/172 [00:41<00:00,  4.15it/s]


Train Loss: 0.3040143227594536, Train Acc: 87.44%


100%|██████████| 74/74 [00:05<00:00, 12.99it/s]


Test Loss: 0.4584590194595827, Test Acc: 80.54%
Saving..


Epoch: 11 Acc: 88.74: 100%|██████████| 172/172 [00:41<00:00,  4.16it/s]


Train Loss: 0.27119057989397716, Train Acc: 88.74%


100%|██████████| 74/74 [00:08<00:00,  8.35it/s]


Test Loss: 0.5114971779890962, Test Acc: 78.33%


Epoch: 12 Acc: 89.39: 100%|██████████| 172/172 [00:41<00:00,  4.15it/s]


Train Loss: 0.26181365558227826, Train Acc: 89.39%


100%|██████████| 74/74 [00:05<00:00, 13.03it/s]


Test Loss: 0.4697663314841889, Test Acc: 80.93%
Saving..


Epoch: 13 Acc: 90.77: 100%|██████████| 172/172 [00:41<00:00,  4.15it/s]


Train Loss: 0.2343534280784255, Train Acc: 90.77%


100%|██████████| 74/74 [00:08<00:00,  8.47it/s]


Test Loss: 0.672826972664208, Test Acc: 76.81%


Epoch: 14 Acc: 91.79: 100%|██████████| 172/172 [00:41<00:00,  4.16it/s]


Train Loss: 0.21429785655075034, Train Acc: 91.79%


100%|██████████| 74/74 [00:05<00:00, 12.94it/s]


Test Loss: 0.5096331863991312, Test Acc: 81.05%
Saving..


Epoch: 15 Acc: 92.58: 100%|██████████| 172/172 [00:41<00:00,  4.16it/s]


Train Loss: 0.19925837530646212, Train Acc: 92.58%


100%|██████████| 74/74 [00:07<00:00, 10.28it/s]


Test Loss: 0.4390257782227284, Test Acc: 83.09%
Saving..


Epoch: 16 Acc: 92.63: 100%|██████████| 172/172 [00:41<00:00,  4.16it/s]


Train Loss: 0.18689408287579237, Train Acc: 92.63%


100%|██████████| 74/74 [00:05<00:00, 12.97it/s]


Test Loss: 0.4810330434828191, Test Acc: 81.39%


Epoch: 17 Acc: 93.76: 100%|██████████| 172/172 [00:41<00:00,  4.16it/s]


Train Loss: 0.17081290381679007, Train Acc: 93.76%


100%|██████████| 74/74 [00:05<00:00, 12.77it/s]


Test Loss: 0.49177900611146075, Test Acc: 83.14%
Saving..


Epoch: 18 Acc: 93.27: 100%|██████████| 172/172 [00:44<00:00,  3.87it/s]


Train Loss: 0.17416348210860824, Train Acc: 93.27%


100%|██████████| 74/74 [00:05<00:00, 12.77it/s]


Test Loss: 0.4567030265524581, Test Acc: 83.52%
Saving..


Epoch: 19 Acc: 94.18: 100%|██████████| 172/172 [00:41<00:00,  4.17it/s]


Train Loss: 0.16586869548962907, Train Acc: 94.18%


100%|██████████| 74/74 [00:05<00:00, 12.79it/s]


Test Loss: 0.468788990297833, Test Acc: 82.41%


Epoch: 20 Acc: 94.23: 100%|██████████| 172/172 [00:44<00:00,  3.86it/s]


Train Loss: 0.16153899679870107, Train Acc: 94.23%


100%|██████████| 74/74 [00:05<00:00, 12.78it/s]


Test Loss: 0.7951070136717848, Test Acc: 75.11%


Epoch: 21 Acc: 94.94: 100%|██████████| 172/172 [00:41<00:00,  4.16it/s]


Train Loss: 0.14087544537560884, Train Acc: 94.94%


100%|██████████| 74/74 [00:05<00:00, 12.82it/s]


Test Loss: 0.4937628541846533, Test Acc: 83.56%
Saving..


Epoch: 22 Acc: 94.14: 100%|██████████| 172/172 [00:44<00:00,  3.88it/s]


Train Loss: 0.14930623790335862, Train Acc: 94.14%


100%|██████████| 74/74 [00:05<00:00, 12.80it/s]


Test Loss: 0.4803539943453428, Test Acc: 83.60%
Saving..


Epoch: 23 Acc: 95.23: 100%|██████████| 172/172 [00:41<00:00,  4.16it/s]


Train Loss: 0.1280974182571003, Train Acc: 95.23%


100%|██████████| 74/74 [00:05<00:00, 12.93it/s]


Test Loss: 0.4471635239551196, Test Acc: 84.88%
Saving..


Epoch: 24 Acc: 94.89: 100%|██████████| 172/172 [00:44<00:00,  3.87it/s]


Train Loss: 0.13377536394698328, Train Acc: 94.89%


100%|██████████| 74/74 [00:05<00:00, 12.82it/s]


Test Loss: 0.47373495573127594, Test Acc: 85.05%
Saving..


Epoch: 25 Acc: 95.60: 100%|██████████| 172/172 [00:41<00:00,  4.16it/s]


Train Loss: 0.11852288525551558, Train Acc: 95.60%


100%|██████████| 74/74 [00:05<00:00, 12.87it/s]


Test Loss: 0.5712896913692758, Test Acc: 82.46%


Epoch: 26 Acc: 95.40: 100%|██████████| 172/172 [00:42<00:00,  4.02it/s]


Train Loss: 0.1232130590956225, Train Acc: 95.40%


100%|██████████| 74/74 [00:07<00:00,  9.93it/s]


Test Loss: 0.6608014876979429, Test Acc: 82.63%


Epoch: 27 Acc: 95.98: 100%|██████████| 172/172 [00:41<00:00,  4.16it/s]


Train Loss: 0.11264735672536284, Train Acc: 95.98%


100%|██████████| 74/74 [00:05<00:00, 12.85it/s]


Test Loss: 0.5108576331388306, Test Acc: 84.62%


Epoch: 28 Acc: 96.16: 100%|██████████| 172/172 [00:41<00:00,  4.16it/s]


Train Loss: 0.10631446121531257, Train Acc: 96.16%


100%|██████████| 74/74 [00:08<00:00,  8.29it/s]


Test Loss: 0.6246718877070659, Test Acc: 81.39%


Epoch: 29 Acc: 96.60: 100%|██████████| 172/172 [00:41<00:00,  4.16it/s]


Train Loss: 0.09088916758746775, Train Acc: 96.60%


100%|██████████| 74/74 [00:05<00:00, 12.78it/s]


Test Loss: 0.5531064441075196, Test Acc: 84.32%


Epoch: 30 Acc: 96.32: 100%|██████████| 172/172 [00:41<00:00,  4.17it/s]


Train Loss: 0.10346636454423153, Train Acc: 96.32%


100%|██████████| 74/74 [00:08<00:00,  8.33it/s]


Test Loss: 0.5524122631680723, Test Acc: 83.52%


Epoch: 31 Acc: 96.47: 100%|██████████| 172/172 [00:41<00:00,  4.15it/s]


Train Loss: 0.09806684109512283, Train Acc: 96.47%


100%|██████████| 74/74 [00:05<00:00, 12.91it/s]


Test Loss: 0.5364966593884133, Test Acc: 85.73%
Saving..


Epoch: 32 Acc: 96.34: 100%|██████████| 172/172 [00:41<00:00,  4.15it/s]


Train Loss: 0.09633631798774438, Train Acc: 96.34%


100%|██████████| 74/74 [00:08<00:00,  8.47it/s]


Test Loss: 0.519954584236886, Test Acc: 84.49%


Epoch: 33 Acc: 97.20: 100%|██████████| 172/172 [00:41<00:00,  4.15it/s]


Train Loss: 0.08129017833663627, Train Acc: 97.20%


100%|██████████| 74/74 [00:05<00:00, 12.83it/s]


Test Loss: 0.6285955258519262, Test Acc: 82.37%


Epoch: 34 Acc: 97.45: 100%|██████████| 172/172 [00:41<00:00,  4.16it/s]


Train Loss: 0.07051895596195272, Train Acc: 97.45%


100%|██████████| 74/74 [00:05<00:00, 12.90it/s]


Test Loss: 0.5821773167196158, Test Acc: 84.07%


Epoch: 35 Acc: 96.62: 100%|██████████| 172/172 [00:44<00:00,  3.87it/s]


Train Loss: 0.09245916588372703, Train Acc: 96.62%


100%|██████████| 74/74 [00:05<00:00, 12.98it/s]


Test Loss: 0.5618621523718577, Test Acc: 84.54%


Epoch: 36 Acc: 97.07: 100%|██████████| 172/172 [00:41<00:00,  4.15it/s]


Train Loss: 0.08027627808312604, Train Acc: 97.07%


100%|██████████| 74/74 [00:05<00:00, 12.82it/s]


Test Loss: 0.5764763428754097, Test Acc: 84.62%


Epoch: 37 Acc: 96.71: 100%|██████████| 172/172 [00:44<00:00,  3.87it/s]


Train Loss: 0.08489202727904803, Train Acc: 96.71%


100%|██████████| 74/74 [00:05<00:00, 12.75it/s]


Test Loss: 0.7204402813536895, Test Acc: 83.31%


Epoch: 38 Acc: 97.51: 100%|██████████| 172/172 [00:41<00:00,  4.16it/s]


Train Loss: 0.07169548021954332, Train Acc: 97.51%


100%|██████████| 74/74 [00:05<00:00, 12.89it/s]


Test Loss: 0.6761362040667115, Test Acc: 82.88%


Epoch: 39 Acc: 97.93: 100%|██████████| 172/172 [00:44<00:00,  3.86it/s]


Train Loss: 0.06222248010656818, Train Acc: 97.93%


100%|██████████| 74/74 [00:05<00:00, 12.80it/s]


Test Loss: 0.5833410728924178, Test Acc: 85.85%
Saving..


Epoch: 40 Acc: 97.73: 100%|██████████| 172/172 [00:41<00:00,  4.16it/s]


Train Loss: 0.06499437407758303, Train Acc: 97.73%


100%|██████████| 74/74 [00:05<00:00, 12.85it/s]


Test Loss: 0.5627053308527212, Test Acc: 86.41%
Saving..


Epoch: 41 Acc: 97.53: 100%|██████████| 172/172 [00:44<00:00,  3.86it/s]


Train Loss: 0.07493736007965582, Train Acc: 97.53%


100%|██████████| 74/74 [00:05<00:00, 12.76it/s]


Test Loss: 0.5896368468734058, Test Acc: 85.47%


Epoch: 42 Acc: 97.62: 100%|██████████| 172/172 [00:41<00:00,  4.15it/s]


Train Loss: 0.07131057868300136, Train Acc: 97.62%


100%|██████████| 74/74 [00:05<00:00, 12.88it/s]


Test Loss: 0.7299303672704343, Test Acc: 81.61%


Epoch: 43 Acc: 97.60: 100%|██████████| 172/172 [00:44<00:00,  3.88it/s]


Train Loss: 0.06491327959621801, Train Acc: 97.60%


100%|██████████| 74/74 [00:05<00:00, 12.75it/s]


Test Loss: 0.6270744426024927, Test Acc: 85.30%


Epoch: 44 Acc: 97.69: 100%|██████████| 172/172 [00:41<00:00,  4.13it/s]


Train Loss: 0.0613944934136884, Train Acc: 97.69%


100%|██████████| 74/74 [00:05<00:00, 12.64it/s]


Test Loss: 0.8246795376007622, Test Acc: 80.80%


Epoch: 45 Acc: 97.71: 100%|██████████| 172/172 [00:41<00:00,  4.16it/s]


Train Loss: 0.059828196580489275, Train Acc: 97.71%


100%|██████████| 74/74 [00:08<00:00,  8.36it/s]


Test Loss: 0.6307680647618867, Test Acc: 85.68%


Epoch: 46 Acc: 97.60: 100%|██████████| 172/172 [00:41<00:00,  4.15it/s]


Train Loss: 0.06696144766634503, Train Acc: 97.60%


100%|██████████| 74/74 [00:05<00:00, 12.79it/s]


Test Loss: 0.5628780935563751, Test Acc: 86.28%


Epoch: 47 Acc: 97.51: 100%|██████████| 172/172 [00:41<00:00,  4.14it/s]


Train Loss: 0.0692619739051605, Train Acc: 97.51%


100%|██████████| 74/74 [00:08<00:00,  8.42it/s]


Test Loss: 0.6608886540540166, Test Acc: 84.58%


Epoch: 48 Acc: 98.16: 100%|██████████| 172/172 [00:41<00:00,  4.14it/s]


Train Loss: 0.0520452436162592, Train Acc: 98.16%


100%|██████████| 74/74 [00:05<00:00, 12.69it/s]


Test Loss: 0.6506119582399323, Test Acc: 84.88%


Epoch: 49 Acc: 96.96: 100%|██████████| 172/172 [00:41<00:00,  4.15it/s]


Train Loss: 0.07710483850543062, Train Acc: 96.96%


100%|██████████| 74/74 [00:08<00:00,  8.30it/s]


Test Loss: 0.5859726953345377, Test Acc: 85.47%
Saving..


In [9]:
traindf = pd.DataFrame(training_info[1:],columns=training_info[0])
testdf = pd.DataFrame(testing_info[1:],columns=testing_info[0])
with open(os.path.join(result_dir,f'basel0_3lbCBAM-b{batch_size}-info.pkl'), 'wb') as file:
    pickle.dump({"train":traindf,"test":testdf}, file)

 
 #   scheduler.step()

In [10]:
testdf

Unnamed: 0,epoch,acc,loss
0,0,61.852167,0.641547
1,1,64.528462,0.62437
2,2,69.201359,0.588759
3,3,65.930331,0.646695
4,4,75.276126,0.535934
5,5,67.629567,0.655371
6,6,72.557349,0.584883
7,7,78.292268,0.474111
8,8,77.86746,0.485812
9,9,78.887001,0.475815


In [11]:
# with open('two_dfs.pkl', 'rb') as file:
#     loaded = pickle.load(file)

# loaded