In [1]:
#from train5; use cbam + 6. MGA
import torch,os 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms import v2
from torchvision import tv_tensors,utils
import torchvision.models as models
import pandas as pd
import matplotlib.pyplot as plt

from nets import ResNet18CBAMMask
from utils import progress_bar
from tqdm import tqdm
from configparser import ConfigParser
from torch.utils.data import  DataLoader
from LIDC_M_data import LIDC_Dataset

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

best_acc = 0     # highest test accuracy observed so far (updated by test())
start_epoch = 0  # first epoch index used by the training loop


def _clip_intensity(x):
    """Clamp image inputs to [-1000, 400]; pass every other input through unchanged."""
    if not isinstance(x, tv_tensors.Image):
        return x
    return tv_tensors.Image(torch.clamp(x, -1000., 400.))


def _rescale_intensity(x):
    """Apply (x + 1000) / 1400 to image inputs; pass every other input through unchanged.

    After the clamp above this maps the clipped range [-1000, 400] onto [0, 1].
    """
    if not isinstance(x, tv_tensors.Image):
        return x
    return tv_tensors.Image((x + 1000) / 1400)


# Deterministic preprocessing shared by the train and test pipelines.
# NOTE(review): the -1000..400 window looks like a CT intensity window - confirm units.
prep_tr = [
    v2.Lambda(_clip_intensity),
    v2.Lambda(_rescale_intensity),
    v2.CenterCrop((384, 384)),
    # v2.Lambda(lambda x: x.expand(3,-1,-1))
]

# Random augmentation, applied to the training pipeline only.
aug_tr = [
    v2.RandomAffine(degrees=10),
    v2.RandomHorizontalFlip(),
    # v2.GaussianNoise(0,0.1)
]

trans_train = v2.Compose([*prep_tr, *aug_tr])
trans_test = v2.Compose(prep_tr)

In [3]:

# Dataset locations are read from the [dataset] section of a local `.settings` INI file.
settings_path = '.settings'
parser = ConfigParser()
# ConfigParser.read() skips files it cannot open and returns the list of files it
# actually parsed, so an empty result means the settings file was never loaded.
# Fail here with a clear message instead of a NoSectionError on the next line.
if not parser.read(settings_path):
    raise FileNotFoundError(
        f"Could not read settings file {settings_path!r}; "
        "it must define root_dir and meta_dir in a [dataset] section."
    )
root_dir = parser.get('dataset','root_dir') # e.g. /workspaces/data/lidc-idri/slices
meta_dir = parser.get('dataset','meta_dir') # e.g. /workspaces/data/lidc-idri/splits

# Both splits share root_dir; they differ in the metadata CSV and in the transform
# pipeline (augmentation is applied to the training split only).
train_data = LIDC_Dataset(root_dir,metapath=os.path.join(meta_dir,'trainBB_malB.csv'),transform=trans_train)
test_data = LIDC_Dataset(root_dir,metapath=os.path.join(meta_dir,'testBB_malB.csv'),transform=trans_test)
total_train_data = len(train_data)
total_test_data = len(test_data)
print('total_train_data:',total_train_data, 'total_test_data:',total_test_data)

# Only the training loader shuffles; the test loader keeps the dataset order.
trainloader = DataLoader(train_data, batch_size=16, shuffle=True)
testloader = DataLoader(test_data, batch_size=16)

total_train_data: 5495 total_test_data: 2353


In [4]:
# Build the CBAM/mask ResNet-18 defined in the project's `nets` module.
net = ResNet18CBAMMask()
# net = models.resnet18(pretrained=True)

# Swap the first convolution for a freshly initialised one that takes 1-channel input.
net.conv1 = nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)

# Swap the classifier for a two-output layer, keeping its input width.
num_classes = 2
net.fc = nn.Linear(in_features=net.fc.in_features, out_features=num_classes)

net = net.to(device)

In [None]:
from torchinfo import summary

# Shape passed to torchinfo: batch of 16 (the DataLoader batch size), one channel,
# 384x384 pixels (the CenterCrop size used in the transforms).
summary_input_shape = (16, 1, 384, 384)
summary(net, input_size=summary_input_shape)

In [5]:
lr = 1e-4

# Classification loss on the network's logits.
criterion = nn.CrossEntropyLoss()
# criterion = nn.BCEWithLogitsLoss()

# Auxiliary loss: train() uses it to compare the attention map with the pooled mask.
mse = nn.MSELoss()

optimizer = optim.Adam(params=net.parameters(), lr=lr)

# Per-epoch metric histories, appended to by train() and test().
trainning_accuracy, trainning_loss = [], []
testing_accuracy, testing_loss = [], []

In [None]:
def train(epoch):
    """Run one optimisation pass over `trainloader` and record the epoch metrics.

    For every batch the loss is `criterion(outputs, targets) + 0.1 * mse(attn_map, masks)`,
    where `masks` is first average-pooled to the spatial size of `attn_map`.
    The mean per-batch loss and the accuracy (in percent) are printed and appended
    to `trainning_loss` and `trainning_accuracy`.

    Args:
        epoch: zero-based epoch index; only used for the progress-bar label.
    """
    net.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0

    batches = tqdm(trainloader, desc=f"[Epoch {epoch+1}]")
    # `step` counts batches from 1, so after the loop it is the number of batches seen.
    for step, (inputs, targets, masks) in enumerate(batches, start=1):
        inputs = inputs.to(device)
        targets = targets.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        # The network returns the class logits together with an attention map.
        outputs, attn_map = net(inputs)

        cls_loss = criterion(outputs, targets)
        # Bring the mask down to the attention map's height/width before comparing.
        pooled_masks = F.adaptive_avg_pool2d(masks, attn_map.shape[-2:])
        mse_loss = mse(attn_map, pooled_masks)
        loss = cls_loss + 0.1 * mse_loss

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = outputs.max(1).indices  # index of the largest logit per sample
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

    train_acc = 100.*n_correct/n_seen
    train_loss = running_loss/step
    print(f"Train Loss: {train_loss}, Train Acc: {train_acc:.2f}%")
    trainning_accuracy.append(train_acc)
    trainning_loss.append(train_loss)

def test(epoch):
    """Evaluate `net` on `testloader`, record the metrics and checkpoint on improvement.

    The mean per-batch cross-entropy loss and the accuracy (in percent) are printed
    and appended to `testing_loss` and `testing_accuracy`. When the accuracy beats
    the module-level `best_acc`, the model weights, accuracy and epoch are saved to
    `./checkpoint/ckpt_owndata.pth` and `best_acc` is updated.

    Args:
        epoch: epoch index stored in the checkpoint.
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets, _) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            # The network returns (logits, attention map); only the logits are scored here.
            outputs, _ = net(inputs)

            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    test_acc = 100.*correct/total
    test_loss = test_loss/(batch_idx+1)
    print(f"Test Loss: {test_loss}, Test Acc: {test_acc:.2f}%")
    testing_accuracy.append(test_acc)
    # test_loss is already the per-batch mean at this point; it must not be divided
    # by the batch count again, or the history would not match the printed value.
    testing_loss.append(test_loss)

    # Save a checkpoint whenever the test accuracy improves.
    if test_acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': test_acc,
            'epoch': epoch,
        }
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt_owndata.pth')
        best_acc = test_acc

In [7]:
# Train for a fixed number of epochs starting at start_epoch; each epoch is followed
# by an evaluation pass, which also saves a checkpoint when the accuracy improves.
num_epochs = 50
end_epoch = start_epoch + num_epochs
for epoch in range(start_epoch, end_epoch):
    train(epoch)
    test(epoch)

    # scheduler.step()

[Epoch 1]: 100%|██████████| 344/344 [05:57<00:00,  1.04s/it]


Train Loss: 0.7048302358666132, Train Acc: 55.81%


TypeError: cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not tuple

In [None]:
# Echo the network so the notebook displays its module tree (the cell's last expression).
net

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): CBAMBasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (cbam): CBAM(
        (ca): ChannelAttention(
          (shared_mlp): Sequential(
            (0): Linear(in_features=64, out_features=4, bias=False)
            (1): ReLU()
            (2): Linear(in_features=4, out_features=64, bias=Fal

In [8]:
from torchinfo import summary

# Layer-by-layer output shapes and parameter counts for one batch-shaped input:
# (batch, channels, height, width) = (16, 1, 384, 384).
batch_shape = (16, 1, 384, 384)
summary(net, input_size=batch_shape)

Layer (type:depth-idx)                             Output Shape              Param #
ResNet18CBAMMask                                   [16, 2]                   --
├─Conv2d: 1-1                                      [16, 64, 192, 192]        3,136
├─BatchNorm2d: 1-2                                 [16, 64, 192, 192]        128
├─ReLU: 1-3                                        [16, 64, 192, 192]        --
├─MaxPool2d: 1-4                                   [16, 64, 96, 96]          --
├─Sequential: 1-5                                  [16, 64, 96, 96]          --
│    └─CBAMBasicBlock: 2-1                         [16, 64, 96, 96]          --
│    │    └─Conv2d: 3-1                            [16, 64, 96, 96]          36,864
│    │    └─BatchNorm2d: 3-2                       [16, 64, 96, 96]          128
│    │    └─ReLU: 3-3                              [16, 64, 96, 96]          --
│    │    └─Conv2d: 3-4                            [16, 64, 96, 96]          36,864
│    │    └─BatchNorm2

In [None]:
# net0 = models.resnet18(pretrained=True)
# net0.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# net0.fc = nn.Linear(net0.fc.in_features, 2)
# net0 = net0.to(device)
# summary(net0, input_size=(16,1, 384, 384))

In [None]:
# Attach a 73728 -> 2 linear layer to the network as `net.linear`.
# `net` was already moved to `device` in an earlier cell, so the new layer is moved
# there too; otherwise its weights would stay on the CPU.
# NOTE(review): whether the model's forward pass uses `linear` is not visible here,
# and the optimizer was created before this layer existed, so it will not update
# these weights unless it is rebuilt - confirm before relying on this cell.
# NOTE(review): 73728 equals 512 * 12 * 12 - presumably a flattened feature map; verify.
net.linear = nn.Linear(73728, 2).to(device)

In [None]:
# Echo the `linear` attribute of the network so its in/out feature sizes are displayed.
net.linear

Linear(in_features=73728, out_features=2, bias=True)