In [None]:
import torch,os 
import torch.nn as nn
import torch.optim as optim

from torchvision.transforms import v2
import pandas as pd
import torchio as tio
import matplotlib.pyplot as plt

from models import *
from utils import progress_bar
from torch.utils.data import  DataLoader
from LIDC_data import LIDC_Dataset

In [None]:
# Run on the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Intensity window every slice is clipped to before rescaling.
# NOTE(review): presumably Hounsfield units for CT slices — confirm against LIDC_Dataset.
WINDOW_LOW, WINDOW_HIGH = -1000, 400

# Deterministic preprocessing shared by the training and test pipelines.
prep_tr = [
    # Clip intensities into [WINDOW_LOW, WINDOW_HIGH].
    v2.Lambda(lambda img: torch.clamp(img, WINDOW_LOW, WINDOW_HIGH)),
    # Shift and scale the clipped window onto [0, 1].
    v2.Lambda(lambda img: (img - WINDOW_LOW) / (WINDOW_HIGH - WINDOW_LOW)),
    v2.CenterCrop((384, 384)),
]

# Random augmentation, appended for training only.
aug_tr = [
    v2.RandomAffine(degrees=15),
    v2.RandomHorizontalFlip(),
    v2.GaussianNoise(0, 0.1),
]

# Training = preprocessing followed by augmentation; testing = preprocessing alone.
trans_train = v2.Compose([*prep_tr, *aug_tr])
trans_test = v2.Compose(prep_tr)

In [None]:
# Dataset root and the directory holding the CSV split definitions.
root_dir = '/data1/lidc-idri/slices'
meta_dir = '/data2/lijin/lidc-prep/kjs/splits'

# The training split uses the augmenting pipeline; the test split is only preprocessed.
train_data = LIDC_Dataset(
    root_dir,
    metapath=os.path.join(meta_dir, 'train_malB.csv'),
    transform=trans_train,
)
test_data = LIDC_Dataset(
    root_dir,
    metapath=os.path.join(meta_dir, 'test_malB.csv'),
    transform=trans_test,
)

# Report split sizes as a quick sanity check on the metadata files.
total_train_data, total_test_data = len(train_data), len(test_data)
print('total_train_data:', total_train_data, 'total_test_data:', total_test_data)

# Same batch size for both loaders; only the training loader shuffles.
BATCH_SIZE = 16
trainloader = DataLoader(train_data, shuffle=True, batch_size=BATCH_SIZE)
testloader = DataLoader(test_data, batch_size=BATCH_SIZE)

In [None]:
# Build the network and place it on the selected device in one step.
net = ResNet18().to(device)

In [None]:
# Optimisation hyper-parameters.
lr = 1e-4
weight_decay = 1e-5

# Binary cross-entropy on raw logits; targets must have the same shape as the logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(net.parameters(), lr=lr, weight_decay=weight_decay)

# Per-epoch history, appended to by train() and test().
trainning_accuracy = []
trainning_loss = []
testing_accuracy = []
testing_loss = []

# Best test accuracy (percent) seen so far. test() declares this name `global`
# and compares against it before saving a checkpoint, so it has to exist before
# the first call; initialising it here keeps Restart & Run All working.
best_acc = 0.0

In [None]:
def train(epoch):
    """Run one optimisation pass over ``trainloader`` and record its metrics.

    Works on the notebook-level ``net``, ``criterion``, ``optimizer``,
    ``device`` and ``trainloader``. After the pass, the epoch's accuracy
    (percent) is appended to ``trainning_accuracy`` and the mean per-batch
    loss to ``trainning_loss``.

    Accuracy is computed to match the loss layout: when the targets have the
    same shape as the logits (what ``BCEWithLogitsLoss`` requires), each logit
    is thresholded at 0; otherwise the arg-max over dimension 1 is compared
    with class-index targets.

    Args:
        epoch: Epoch number, used only for the progress printout.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)

        # NOTE(review): BCEWithLogitsLoss needs floating-point targets shaped like
        # `outputs` — confirm LIDC_Dataset yields labels in that form.
        loss = criterion(outputs, targets)
        loss.backward()

        optimizer.step()

        train_loss += loss.item()
        if outputs.shape == targets.shape:
            # One logit per target: sigmoid(z) > 0.5 is equivalent to z > 0.
            predicted = (outputs > 0).to(targets.dtype)
        else:
            # Class-index targets: pick the highest-scoring class.
            _, predicted = outputs.max(1)
        # Count compared elements so that `correct` can never exceed `total`.
        total += targets.numel()
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    trainning_accuracy.append(100.*correct/total)
    trainning_loss.append(train_loss/(batch_idx+1))

def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
        testing_accuracy.append(100.*correct/total)
        testing_loss.append(test_loss/(batch_idx+1))
    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt_owndata.pth')
        best_acc = acc