In [1]:
# train3 (from train2): use Adam, remove weight_decay; one grayscale (BW) input channel instead of 3
import torch,os,glob
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchvision.transforms import v2
from torchvision import tv_tensors
import pandas as pd
import numpy as np

from nets import ResNet18_lbCM
from tqdm import tqdm
from configparser import ConfigParser
from torch.utils.data import  DataLoader
from LIDC_Mpad_data import LIDC_Dataset
import pickle

In [2]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Bookkeeping shared with the training / evaluation cells.
best_acc = 0      # best test accuracy
best_epoch = 0
start_epoch = 0   # start from epoch 0 or last checkpoint epoch

# Intensity window applied to every image before it is rescaled to [0, 1].
# NOTE(review): presumably a Hounsfield-unit window for CT slices -- confirm.
CLIP_MIN = -1000.
CLIP_MAX = 400.


def _clip_intensity(x):
    """Clamp Image inputs to [CLIP_MIN, CLIP_MAX]; return anything else untouched."""
    if isinstance(x, tv_tensors.Image):
        return tv_tensors.Image(torch.clamp(x, CLIP_MIN, CLIP_MAX))
    return x


def _rescale_intensity(x):
    """Map Image inputs from [CLIP_MIN, CLIP_MAX] onto [0, 1]; return anything else untouched."""
    if isinstance(x, tv_tensors.Image):
        return tv_tensors.Image((x - CLIP_MIN) / (CLIP_MAX - CLIP_MIN))
    return x


# Deterministic preprocessing, used for both training and testing.
prep_tr = [
    v2.Lambda(_clip_intensity),
    v2.Lambda(_rescale_intensity),
    v2.CenterCrop((384, 384)),
]
# Random augmentation, appended for training only.
aug_tr = [
    v2.RandomAffine(degrees=10),
    v2.RandomHorizontalFlip(),
]
trans_train = v2.Compose(prep_tr + aug_tr)
trans_test = v2.Compose(prep_tr)

In [3]:

# Read dataset locations and loader settings from the local '.settings' INI file.
parser = ConfigParser()
if not parser.read('.settings'):
    # ConfigParser.read() silently skips files it cannot open; fail here with a
    # clear message instead of a NoSectionError on the first lookup below.
    raise FileNotFoundError("could not read '.settings' in the current working directory")

root_dir = parser.get('dataset','root_dir') #/workspaces/data/lidc-idri/slices
meta_dir = parser.get('dataset','meta_dir') #/workspaces/data/lidc-idri/splits
result_dir = os.path.join(parser.get('dataset','result_dir'),'stage2/basel0_lb4MGA-tune')
# result_dir sits two levels below the configured directory, so intermediate
# folders (e.g. 'stage2') must be created too; exist_ok makes re-runs harmless.
os.makedirs(result_dir, exist_ok=True)

# The training set is built with loadBB=True and the augmenting transform; the
# test set uses the deterministic transform only.
# NOTE(review): loadBB=True presumably makes each training sample also carry a
# mask (train() unpacks three items per batch) -- confirm in LIDC_Dataset.
train_data = LIDC_Dataset(root_dir,metapath=os.path.join(meta_dir,'trainBB_malB.csv'),transform=trans_train, loadBB=True)
test_data = LIDC_Dataset(root_dir,metapath=os.path.join(meta_dir,'testBB_malB.csv'),transform=trans_test)
total_train_data = len(train_data)
total_test_data = len(test_data)
print('total_train_data:',total_train_data, 'total_test_data:',total_test_data)

# getint() performs the same int() conversion as before, via the same accessor
# style used for the other settings.
batch_size = parser.getint('dataset','batch_size')
trainloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8)
testloader = DataLoader(test_data, batch_size=batch_size, num_workers=8)

total_train_data: 5495 total_test_data: 2354


In [4]:
# Backbone constructed with pretrained=True and the "LA4_BL2" attribute string.
net = ResNet18_lbCM(pretrained=True, attr="LA4_BL2")

# Replace the stem with a convolution that takes a single input channel.
# NOTE(review): this layer is newly initialised, so whatever weights the
# constructor placed in conv1 are discarded -- confirm that is intended.
stem = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)
net.conv1 = stem

# Replace the classification head with a 2-output linear layer of matching width.
head_width = net.fc.in_features
net.fc = nn.Linear(head_width, 2)

net = net.to(device)

In [5]:
from torchinfo import summary

# Layer-by-layer summary for one batch of single-channel 384x384 inputs
# (384 matches the CenterCrop size used in the transforms).
input_shape = (batch_size, 1, 384, 384)
summary(net, input_size=input_shape)

Layer (type:depth-idx)                             Output Shape              Param #
ResNet18_lbCM                                      [32, 2]                   --
├─Conv2d: 1-1                                      [32, 64, 192, 192]        3,136
├─BatchNorm2d: 1-2                                 [32, 64, 192, 192]        128
├─ReLU: 1-3                                        [32, 64, 192, 192]        --
├─MaxPool2d: 1-4                                   [32, 64, 96, 96]          --
├─Sequential: 1-5                                  [32, 64, 96, 96]          --
│    └─BasicBlock: 2-1                             [32, 64, 96, 96]          --
│    │    └─Conv2d: 3-1                            [32, 64, 96, 96]          36,864
│    │    └─BatchNorm2d: 3-2                       [32, 64, 96, 96]          128
│    │    └─ReLU: 3-3                              [32, 64, 96, 96]          --
│    │    └─Conv2d: 3-4                            [32, 64, 96, 96]          36,864
│    │    └─BatchNorm2

In [6]:
# Optimisation setup: cross-entropy for classification, MSE for the attention
# term used in train(), and Adam at a fixed learning rate.
lr = 1e-4
criterion = nn.CrossEntropyLoss()
mse = nn.MSELoss()
optimizer = optim.Adam(params=net.parameters(), lr=lr)

# Per-epoch logs: a header row followed by one [epoch, acc, loss] row per epoch.
# The two logs are separate outer lists that share the same header row object.
info_header = ["epoch", "acc", "loss"]
training_info = [info_header]
testing_info = list(training_info)

In [7]:
def train(epoch):
    """Run one training epoch over ``trainloader`` and log its metrics.

    Uses the module-level ``net``, ``optimizer``, ``criterion``, ``mse`` and
    ``device``. Each batch provides ``(inputs, targets, masks)``; in training
    mode the model returns ``(logits, attention_map)``. The optimised loss is
    the classification loss plus 10x the MSE between the attention map and the
    mask average-pooled to the attention map's spatial size.

    Args:
        epoch: Epoch index, used for the progress-bar text and the log row.

    Side effects:
        Updates ``net`` through ``optimizer``, prints the mean losses and the
        accuracy, and appends ``[epoch, accuracy, mean_losses]`` to
        ``training_info`` where ``mean_losses`` is a length-3 array ordered
        (total, classification, attention).
    """
    net.train()
    loss_sums = np.zeros(3)  # running sums: total, classification, attention
    n_correct = 0
    n_seen = 0
    n_batches = 0
    progress = tqdm(trainloader)
    for inputs, targets, masks in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        logits, attn_map = net(inputs)

        cls_loss = criterion(logits, targets)
        # Bring the mask down to the attention map's resolution before comparing.
        pooled_masks = F.adaptive_avg_pool2d(masks, attn_map.shape[-2:])
        att_loss = 10.* mse(attn_map, pooled_masks)
        loss = cls_loss + att_loss

        loss.backward()
        optimizer.step()

        loss_sums += np.array([loss.item(), cls_loss.item(), att_loss.item()])
        predicted = logits.max(1)[1]

        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()
        n_batches += 1

        progress.set_description(f"Epoch: {epoch} Acc: {(100.*n_correct/n_seen):.2f}")

    train_acc = 100.*n_correct/n_seen
    mean_losses = loss_sums/n_batches
    print(f"Tot Loss: {mean_losses[0]:.4f} CL: {mean_losses[1]:.5f} AT: {mean_losses[2]:.5f}; Train Acc: {train_acc:.2f}%")
    training_info.append([epoch, train_acc, mean_losses])

def test(epoch, islast = False):
    """Evaluate ``net`` on ``testloader``, log the metrics and checkpoint.

    Uses the module-level ``net``, ``criterion``, ``device``, ``result_dir``
    and ``batch_size``. The model output is passed straight to ``criterion``,
    so in eval mode it is expected to return the logits only.

    A checkpoint is written when the accuracy beats ``best_acc`` (file suffix
    ``best``) or when ``islast`` is true (suffix ``last`` unless this epoch is
    also a new best).

    Args:
        epoch: Epoch index, used in the log row and the checkpoint file name.
        islast: Force a checkpoint even when the accuracy is not a new best.

    Side effects:
        Prints the mean loss and accuracy, appends ``[epoch, acc, loss]`` to
        ``testing_info``, may write a ``.pth`` file under ``result_dir``, and
        updates the globals ``best_acc`` / ``best_epoch`` only when this epoch
        really is a new best.
    """
    global best_acc, best_epoch
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    pbar = tqdm(testloader)
    with torch.no_grad():
        for batch_idx, (inputs, targets ) in enumerate(pbar):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)

            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    test_acc = 100.*correct/total
    test_loss = test_loss/(batch_idx+1)
    print(f"Test Loss: {test_loss}, Test Acc: {test_acc:.2f}%")
    testing_info.append([epoch,test_acc,test_loss])

    # Save checkpoint.
    is_best = test_acc > best_acc
    if is_best or islast:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': test_acc,
            'epoch': epoch,
        }
        savestr = 'best' if is_best else 'last'
        torch.save(state, os.path.join(result_dir,f'basel0-b{batch_size}-epoch{epoch}-{savestr}.pth'))
        # Only a genuine improvement may move the best-so-far trackers; a forced
        # "last" save must not replace them with a lower accuracy.
        if is_best:
            best_acc = test_acc
            best_epoch = epoch

In [8]:
# Scratch value: a plain three-element list. The former np.zeros(3) assignment
# was removed because it was rebound on the very next line before any use.
aa = [1, 2, 3]

In [9]:
# Scratch check: print the value currently bound to `aa`.
print(aa)

[1, 2, 3]


In [None]:
# Number of epochs to run in this session (counted from start_epoch).
num_epochs = 50

# Resume: reload the checkpoint written for the epoch just before start_epoch.
# NOTE: checkpoints written by test() contain only the model weights, the
# accuracy and the epoch, so the optimizer state starts fresh after a resume.
if start_epoch>0:
    ckpt_pattern = os.path.join(result_dir,f'basel0-b{batch_size}-epoch{start_epoch-1}-*.pth')
    ckpt_matches = glob.glob(ckpt_pattern)
    if not ckpt_matches:
        # Explicit error instead of a bare IndexError from indexing an empty list.
        raise FileNotFoundError(f"no checkpoint matches {ckpt_pattern!r}; cannot resume at epoch {start_epoch}")
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    checkpoint = torch.load(ckpt_matches[0], map_location=device)
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']

final_epoch = start_epoch + num_epochs - 1
for epoch in range(start_epoch, start_epoch + num_epochs):
    train(epoch)
    # Force a checkpoint after the final epoch even if it is not the best one.
    test(epoch, islast = epoch == final_epoch)

Epoch: 0 Acc: 56.03: 100%|██████████| 172/172 [00:36<00:00,  4.70it/s]


Tot Loss: 0.9554 CL: 0.68296 AT: 0.27246; Train Acc: 56.03%


100%|██████████| 74/74 [00:05<00:00, 14.80it/s]


Test Loss: 0.6628181692716237, Test Acc: 60.28%
Saving..


Epoch: 1 Acc: 63.51: 100%|██████████| 172/172 [00:39<00:00,  4.35it/s]


Tot Loss: 0.8599 CL: 0.64284 AT: 0.21707; Train Acc: 63.51%


100%|██████████| 74/74 [00:05<00:00, 13.93it/s]


Test Loss: 0.6072840742968224, Test Acc: 67.84%
Saving..


Epoch: 2 Acc: 69.10: 100%|██████████| 172/172 [00:38<00:00,  4.41it/s]


Tot Loss: 0.7801 CL: 0.58961 AT: 0.19047; Train Acc: 69.10%


100%|██████████| 74/74 [00:05<00:00, 13.52it/s]


Test Loss: 0.5747128414946634, Test Acc: 71.20%
Saving..


Epoch: 3 Acc: 71.50: 100%|██████████| 172/172 [00:39<00:00,  4.32it/s]


Tot Loss: 0.7312 CL: 0.55506 AT: 0.17612; Train Acc: 71.50%


100%|██████████| 74/74 [00:07<00:00,  9.31it/s]


Test Loss: 0.5549086267883713, Test Acc: 73.15%
Saving..


Epoch: 4 Acc: 74.30: 100%|██████████| 172/172 [00:38<00:00,  4.44it/s]


Tot Loss: 0.7006 CL: 0.53122 AT: 0.16939; Train Acc: 74.30%


100%|██████████| 74/74 [00:05<00:00, 13.97it/s]


Test Loss: 0.5339989392338572, Test Acc: 74.38%
Saving..


Epoch: 5 Acc: 76.27: 100%|██████████| 172/172 [00:37<00:00,  4.55it/s]


Tot Loss: 0.6606 CL: 0.49542 AT: 0.16516; Train Acc: 76.27%


100%|██████████| 74/74 [00:05<00:00, 12.48it/s]


Test Loss: 0.506007073296083, Test Acc: 75.70%
Saving..


Epoch: 6 Acc: 78.40: 100%|██████████| 172/172 [00:37<00:00,  4.61it/s]


Tot Loss: 0.6253 CL: 0.46449 AT: 0.16079; Train Acc: 78.40%


100%|██████████| 74/74 [00:08<00:00,  8.97it/s]


Test Loss: 0.5628580302000046, Test Acc: 72.39%


Epoch: 7 Acc: 80.47: 100%|██████████| 172/172 [00:37<00:00,  4.57it/s]


Tot Loss: 0.5980 CL: 0.43783 AT: 0.16014; Train Acc: 80.47%


100%|██████████| 74/74 [00:05<00:00, 14.34it/s]


Test Loss: 0.4986705691427798, Test Acc: 76.51%
Saving..


Epoch: 8 Acc: 81.95: 100%|██████████| 172/172 [00:37<00:00,  4.63it/s]


Tot Loss: 0.5611 CL: 0.40694 AT: 0.15415; Train Acc: 81.95%


100%|██████████| 74/74 [00:05<00:00, 14.14it/s]


Test Loss: 0.45858245282559784, Test Acc: 78.93%
Saving..


Epoch: 9 Acc: 83.80: 100%|██████████| 172/172 [00:37<00:00,  4.57it/s]


Tot Loss: 0.5204 CL: 0.36679 AT: 0.15362; Train Acc: 83.80%


100%|██████████| 74/74 [00:08<00:00,  9.02it/s]


Test Loss: 0.4423719249867104, Test Acc: 80.12%
Saving..


Epoch: 10 Acc: 85.53: 100%|██████████| 172/172 [00:36<00:00,  4.72it/s]


Tot Loss: 0.4915 CL: 0.33910 AT: 0.15244; Train Acc: 85.53%


100%|██████████| 74/74 [00:05<00:00, 12.67it/s]


Test Loss: 0.44048896150009054, Test Acc: 80.50%
Saving..


Epoch: 11 Acc: 86.97: 100%|██████████| 172/172 [00:37<00:00,  4.64it/s]


Tot Loss: 0.4554 CL: 0.30477 AT: 0.15066; Train Acc: 86.97%


100%|██████████| 74/74 [00:05<00:00, 13.76it/s]


Test Loss: 0.4229389060590718, Test Acc: 80.80%
Saving..


Epoch: 12 Acc: 88.57: 100%|██████████| 172/172 [00:36<00:00,  4.67it/s]


Tot Loss: 0.4277 CL: 0.27798 AT: 0.14974; Train Acc: 88.57%


100%|██████████| 74/74 [00:08<00:00,  8.43it/s]


Test Loss: 0.5459978274396948, Test Acc: 75.79%


Epoch: 13 Acc: 89.41: 100%|██████████| 172/172 [00:37<00:00,  4.60it/s]


Tot Loss: 0.4086 CL: 0.26077 AT: 0.14787; Train Acc: 89.41%


100%|██████████| 74/74 [00:07<00:00, 10.20it/s]


Test Loss: 0.44098464621079936, Test Acc: 81.22%
Saving..


Epoch: 14 Acc: 90.37: 100%|██████████| 172/172 [00:37<00:00,  4.64it/s]


Tot Loss: 0.3843 CL: 0.23772 AT: 0.14658; Train Acc: 90.37%


100%|██████████| 74/74 [00:05<00:00, 12.74it/s]


Test Loss: 0.500465494152662, Test Acc: 80.03%


Epoch: 15 Acc: 91.26: 100%|██████████| 172/172 [00:41<00:00,  4.11it/s]


Tot Loss: 0.3693 CL: 0.22508 AT: 0.14418; Train Acc: 91.26%


100%|██████████| 74/74 [00:05<00:00, 14.16it/s]


Test Loss: 0.5137602384831454, Test Acc: 77.95%


Epoch: 16 Acc: 91.83: 100%|██████████| 172/172 [00:38<00:00,  4.41it/s]


Tot Loss: 0.3402 CL: 0.19853 AT: 0.14167; Train Acc: 91.83%


100%|██████████| 74/74 [00:06<00:00, 11.56it/s]


Test Loss: 0.41689552195571566, Test Acc: 82.29%
Saving..


Epoch: 17 Acc: 93.14: 100%|██████████| 172/172 [00:46<00:00,  3.68it/s]


Tot Loss: 0.3216 CL: 0.18123 AT: 0.14041; Train Acc: 93.14%


100%|██████████| 74/74 [00:07<00:00, 10.18it/s]


Test Loss: 0.4713629484176636, Test Acc: 80.54%


Epoch: 18 Acc: 93.33:   9%|▊         | 15/172 [00:04<00:32,  4.87it/s]

In [None]:
# Turn the per-epoch logs (header row first, then data rows) into DataFrames.
train_header, train_rows = training_info[0], training_info[1:]
test_header, test_rows = testing_info[0], testing_info[1:]
traindf = pd.DataFrame(train_rows, columns=train_header)
testdf = pd.DataFrame(test_rows, columns=test_header)

# Persist both tables together in one pickle inside the results folder.
info_path = os.path.join(result_dir, f'basel0-b{batch_size}-info.pkl')
with open(info_path, 'wb') as file:
    pickle.dump({"train": traindf, "test": testdf}, file)

 
 #   scheduler.step()

In [None]:
import re

# Scratch check: pull the digits that follow 'MGA' out of a sample tag and
# fall back to 3 when the tag carries no number.
levels = re.findall(r'MGA(\d+)', 'MGA22')
if levels:
    mga_level = int(levels[0])
else:
    mga_level = 3
# 'MGA3'.startswith('MGA')
mga_level

22

In [None]:
# Display the test-metrics DataFrame (rendered as the cell's last expression).
testdf

Unnamed: 0,epoch,acc,loss
0,0,62.149533,0.649075
1,1,66.482583,0.609698
2,2,71.750212,0.566307
3,3,72.982158,0.542833
4,4,73.831776,0.516786
5,5,75.615973,0.516502
6,6,74.171623,0.512959
7,7,71.070518,0.566263
8,8,79.864061,0.427543
9,9,79.099405,0.479788


In [None]:
# with open('two_dfs.pkl', 'rb') as file:
#     loaded = pickle.load(file)

# loaded