In [1]:
import cv2
import sys
import torch
import random
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
sys.path.append("../")
from abus_classification.datasets import TDSCTumors
from abus_classification.models import MultiTaskSegmentationClassificationABUS3D

In [11]:
class TransposeTransformer:
    
    def __init__(self, shape):
        self.shape = shape
        
    def __call__(self, *inputs):
        
        res = []
        
        for i in inputs:
            res.append(np.transpose(i, self.shape))
            
        return tuple(res) if len(res) > 1 else res[0]
                
class ResizeTransformer:
    
    def __init__(self, size):
        self.size = size
        
    def __call__(self, *inputs):
        outputs = []
        for data in inputs:
            resized_data = np.zeros((len(data),*self.size), dtype=np.float32)
            for idx, sli in enumerate(data):
                resized_data[idx] = cv2.resize(sli, self.size)
            outputs.append(resized_data)
        return tuple(outputs) if len(outputs) > 1 else outputs[0]

class ToTensorTransformer:    
    def __call__(self, data, mask):
        return torch.from_numpy(data), torch.from_numpy(mask)


tumors = TDSCTumors(path="../data/tdsc", transforms=[
    TransposeTransformer((1,0,2)), 
    ResizeTransformer((128,128)), 
    ToTensorTransformer()
    ])        

In [12]:
num_epochs = 100
learning_rate = 10e-3
alpha = .5
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MultiTaskSegmentationClassificationABUS3D().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.BCELoss()

In [15]:
def train(model, train_dataset):
    train_acc = 0
    loop = tqdm(range(num_epochs))
    for epoch in loop:
        for idx in train_dataset:
            volume, mask, label = tumors[idx]
            for i in range(0,len(volume), 10):
                x = volume[i:i+64] / 255
                m = mask[i:i+64]
                end = i + 64
                if end > len(volume):
                    len_ext = end - len(volume)
                    ext_tensor = torch.zeros([len_ext, 128, 128], dtype=torch.float32)
                    x = torch.cat((x, ext_tensor), dim=0)
                    m = torch.cat((m, ext_tensor), dim=0)
                x = x.unsqueeze(0).unsqueeze(0)
                m = m.unsqueeze(0).unsqueeze(0)
                y = torch.tensor([[label]], dtype=torch.float32)  
                
                x = x.to(device)
                m = m.to(device)
                y = y.to(device)
                
                prediction_seg, prediction_cls = model(x)
                loss_seg = criterion(m, prediction_seg)
                loss_cls = criterion(y, prediction_cls)
                
                loss = alpha*loss_seg + (1-alpha)*loss_cls

                loss.backward()
                optimizer.step()
                model.eval()
                with torch.no_grad():
                    train_acc = ((prediction_cls > .5).float() == y).int().sum().item()
                loop.set_postfix(loss=loss.item(), loss_seg=loss_seg.item(), loss_cls=loss_cls.item())   
                model.train()
                
    print(f"train accuracy:{train_acc/len(train_dataset)}")

random.seed(42)
train_list = [i for i in range(100)]
# random.shuffle(train_list)
train(model, train_list)

  0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [6]:
test_acc = 0

for test_data_idx in len(tumors):
    print(f"Using data {test_data_idx} as test data...")
    
    model.eval()
    with torch.no_grad():
        volume, mask, label = tumors[test_data_idx]
        middle_sli = len(volume)//2
        
        x = volume[middle_sli-32:middle_sli+32]
        m = volume[middle_sli-32:middle_sli+32]
        y = torch.tensor([[label]], dtype=torch.float32)  

        x = x.unsqueeze(0)
        m = m.unsqueeze(0)
        
        x = x.to(device)
        m = m.to(device)
        y = y.to(device)
        
        prediction_seg, prediction_cls = model(x)
        prediction_cls = (prediction_cls > .5).float()
        prediction_seg = (prediction_seg > .5).float()
        
        prediction_seg = (m == prediction_seg).int()
        prediction_cls = (y == prediction_cls).int()
        
        test_acc += prediction_cls.sum().item()
        
        
    model.train()

  0%|          | 0/100 [00:00<?, ?it/s]

torch.Size([1, 64, 128, 128])
torch.Size([1, 64, 128, 128])
torch.Size([1, 1])
