In [1]:
import os

import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

# project-local packages
import datasets
from models import unet

## Loading the dataset and transforms

In [2]:
# Training pipeline: resize, random geometric augmentation, then scale to [0, 1].
# With mean=0 and std=1, albumentations' Normalize reduces to dividing by
# max_pixel_value, i.e. 8-bit intensities [0, 255] -> [0, 1].
train_transforms = A.Compose([
    A.Resize(256, 256),
    A.Rotate(limit=35, p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=255.0),
    ToTensorV2(),
])

# Validation pipeline: same resize and scaling, no random augmentation.
validation_transforms = A.Compose([
    A.Resize(256, 256),
    A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=255.0),
    ToTensorV2(),
])

In [3]:
# Both splits read from the same slice directory; TDSC_2D's `train` flag
# presumably selects the split (confirm in datasets/tdsc_2d.py).
train_dataset = datasets.tdsc_2d.TDSC_2D(
    path="./data/tdsc/slices",
    train=True,
    transforms=train_transforms,
)
validation_dataset = datasets.tdsc_2d.TDSC_2D(
    path="./data/tdsc/slices",
    train=False,
    transforms=validation_transforms,
)

In [7]:
# Inspect one (image, mask, label) sample after the training transforms.
x, y, l = train_dataset[0]
print(*(t.shape for t in (x, y)))

torch.Size([1, 256, 256]) torch.Size([256, 256])


## Building the model

In [4]:
class ABUSClassifier(torch.nn.Module):
    """UNet segmenter with an additional slice-level classification head.

    The UNet bottleneck activation (read via ``get_bottleneck_output()`` after
    the UNet forward pass) is flattened and fed through an MLP that yields one
    classification logit per slice.

    Args:
        device: Stored as ``self.device``; it does not move the module
            (call ``.to(device)`` for that).
        bottleneck_channels: Channel count of the UNet bottleneck output.
        bottleneck_size: Spatial side length of the (square) bottleneck output.
            The defaults reproduce the previously hard-coded ``1024*16*16``
            flatten size, which presumably corresponds to 256x256 inputs for
            ``unet.UNet`` -- confirm if the input size or UNet depth changes.
    """

    def __init__(self, device="cpu", bottleneck_channels=1024, bottleneck_size=16):
        super().__init__()
        self.device = device
        # Base model: 1-channel image in, 1-channel segmentation map out.
        self.base_model = unet.UNet(in_channels=1, out_channels=1)

        # Slice classification MLP over the flattened bottleneck.
        # NOTE(review): the first Linear holds flat_features * 1024 weights
        # (~268M with the defaults); pooling the bottleneck before flattening
        # would shrink the model drastically.
        flat_features = bottleneck_channels * bottleneck_size * bottleneck_size
        self.slice_classification_block = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(flat_features, 1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
        )
        self.slice_classifier = torch.nn.Linear(256, 1)
        # TODO: volume-level classification (aggregating slice features across
        # a volume) is not implemented yet.

    def forward(self, x):
        """Run segmentation and slice classification in one pass.

        Returns:
            ``(segmentation_pred, cls_predictions)``. Neither output has an
            activation applied here; callers apply ``sigmoid``.
            ``cls_predictions`` has one value per slice (shape ``(N, 1)``).
        """
        segmentation_pred = self.base_model(x)
        # Must be read after base_model(x) so it reflects this forward pass.
        bottle_neck_features = self.base_model.get_bottleneck_output()
        cls_features = self.slice_classification_block(bottle_neck_features)
        cls_predictions = self.slice_classifier(cls_features)
        return segmentation_pred, cls_predictions



In [5]:
# Smoke test: push one random 1-channel 256x256 image through a fresh model and
# check the output shapes. no_grad avoids building an autograd graph for this
# throwaway forward pass of a very large model.
model = ABUSClassifier()
model.eval()
x = torch.rand([1, 1, 256, 256])
print(x.shape)
with torch.no_grad():
    seg, cls = model(x)
print(seg.shape, cls.shape)
assert seg.shape == x.shape, "segmentation output should match the input size"
assert cls.shape == (x.shape[0], 1), "expected one classification logit per slice"

torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256]) torch.Size([1, 1])




## Loss function and training hyperparameters

In [5]:
class DiceLoss(torch.nn.Module):
    """Soft Dice loss pooled over the whole batch (squared-sum denominator).

    loss = 1 - (2 * sum(p * t) + smooth) / (sum(p^2) + sum(t^2) + smooth)

    Both tensors are flattened, so the score is computed over every element of
    the batch at once rather than averaged per sample.

    Args:
        smooth: Additive smoothing term. It keeps the loss defined, and equal
            to 0, when prediction and target are both all-zero.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Compute the Dice loss.

        Args:
            pred: Predicted probabilities (e.g. sigmoid output), any shape.
            target: Mask with the same number of elements as ``pred``.

        Returns:
            Scalar float32 tensor; lies in [0, 1] when inputs are in [0, 1].
        """
        # Accumulate in float32: under autocast `pred` may be float16, and the
        # mask may arrive as an integer tensor.
        iflat = pred.reshape(-1).float()
        tflat = target.reshape(-1).float()
        intersection = (iflat * tflat).sum()
        a_sum = (iflat * iflat).sum()
        b_sum = (tflat * tflat).sum()
        return 1 - (2.0 * intersection + self.smooth) / (a_sum + b_sum + self.smooth)

In [6]:
# Reproducibility: seed the RNGs used for weight init, DataLoader shuffling and
# augmentation before anything random is created.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

mini_batch_size = 8
learning_rate = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"
num_epochs = 10
# Intended per-task loss weights.
# NOTE(review): alpha/beta/gamma are not used anywhere below -- the training
# loop applies its own classification/segmentation weights, and the volume
# classification task is not implemented.
alpha = 0.3  # slice classification
beta = 0.2   # slice segmentation
gamma = 0.5  # volume classification
criterion_bce = torch.nn.BCEWithLogitsLoss()
criterion_dice = DiceLoss()
model = ABUSClassifier(device=device).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Loss scaling is only meaningful for CUDA mixed precision; disable it
# explicitly on CPU rather than relying on the runtime fallback warning.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

In [7]:
# Validation order is fixed (shuffle=False) so accuracy runs and the first
# validation batch saved as images are comparable across epochs.
# Pinned host memory only speeds up host->GPU copies, so enable it on CUDA only.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=mini_batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=(device == "cuda"),
)
validation_dataloader = DataLoader(
    validation_dataset,
    batch_size=mini_batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=(device == "cuda"),
)

In [8]:
# Peek at one training batch to confirm the shapes the model will receive.
data = next(iter(train_dataloader))
print(len(data))
x, y, l = data
y = y.unsqueeze(1)
l = l.unsqueeze(1)
print(x.shape, y.shape, l.shape)
# The segmentation loss compares an image-sized prediction with the mask, so
# image and mask must share their spatial size.
assert x.shape[-2:] == y.shape[-2:], "image/mask spatial size mismatch"

3
torch.Size([8, 1, 500, 500]) torch.Size([8, 1, 500, 500]) torch.Size([8, 1])


In [8]:
def train(model, dataset, optimizer, loss_fun_seg, loss_fun_cls,
          cls_weight=0.8, seg_weight=0.2, mask_hint=0.0):
    """Train ``model`` for one epoch with CUDA automatic mixed precision.

    Args:
        model: Module whose forward returns ``(segmentation_logits, classification_logits)``.
        dataset: Iterable of ``(image, mask, label)`` batches (a DataLoader here).
        optimizer: Optimizer over ``model``'s parameters.
        loss_fun_seg: Segmentation loss, called as ``loss_fun_seg(sigmoid(seg_logits), mask)``.
        loss_fun_cls: Classification loss, called on raw logits (e.g. BCEWithLogitsLoss).
        cls_weight: Weight of the classification loss (previously hard-coded 0.8).
        seg_weight: Weight of the segmentation loss (previously hard-coded 0.2).
        mask_hint: Debug-only. When non-zero, the image is darkened by this
            fraction inside the ground-truth mask (``x - x*mask*mask_hint``),
            which the previous version always did with 0.3. That feeds the
            target into the input (label leakage), so keep 0.0 for real training.

    Returns:
        Mean total loss over the batches seen (0.0 if ``dataset`` is empty).

    Relies on the notebook globals ``device`` and ``scaler`` (a GradScaler).
    """
    # Clear the CUDA cache once per epoch; doing it every iteration only slows
    # training down.
    torch.cuda.empty_cache()

    print("training model...")
    model.train()
    loop = tqdm(dataset)
    total_loss = 0.0
    num_batches = 0

    for x, y, l in loop:
        # Add a channel dim so mask/label line up with the (N, 1, ...) outputs;
        # BCEWithLogitsLoss requires a floating-point target.
        x = x.to(device)
        y = y.unsqueeze(1).float().to(device)
        l = l.unsqueeze(1).float().to(device)
        if mask_hint:
            x = x - x * y * mask_hint

        # forward
        with torch.cuda.amp.autocast():
            seg_predictions, cls_predictions = model(x)
            seg_predictions = torch.sigmoid(seg_predictions)
            cls_loss = loss_fun_cls(cls_predictions, l)
            seg_loss = loss_fun_seg(seg_predictions, y)
            loss = cls_weight * cls_loss + seg_weight * seg_loss

        # backward (GradScaler guards float16 gradients against underflow)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1
        loop.set_postfix(loss=loss_value)

    # TODO: volume-level classification (aggregating slice predictions per
    # volume into a volume loss) is not implemented yet.
    return total_loss / num_batches if num_batches else 0.0

In [9]:
def calculate_accuracy(dataset, model, device=None, mask_hint=0.0):
    """Compute the slice-classification accuracy of ``model`` on ``dataset``.

    Args:
        dataset: Iterable of ``(image, mask, label)`` batches (a DataLoader here).
        model: Module whose forward returns ``(segmentation_logits, classification_logits)``.
        device: Device to run on. Defaults to "cuda" when available, else "cpu"
            (the previous hard-coded "cuda" default failed on CPU-only machines).
        mask_hint: Debug-only darkening of the image inside the ground-truth
            mask (previously always 0.3). It leaks the target into the input,
            so keep 0.0 for honest evaluation.

    Returns:
        Fraction of correctly classified slices over the samples actually
        evaluated (previously divided by ``len(train_dataset)``, the wrong
        split); 0.0 if ``dataset`` is empty. The model is left in train mode.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    print("calculating model accuracy...")
    num_correct = 0
    num_seen = 0

    model.eval()
    loop = tqdm(dataset)

    with torch.no_grad():
        for x, y, l in loop:
            l = l.unsqueeze(1)
            if mask_hint:
                x = x - x * y.unsqueeze(1) * mask_hint

            x = x.to(device)
            l = l.to(device)

            # forward: threshold sigmoid probabilities at 0.5
            _, cls_predictions = model(x)
            cls_predictions = (torch.sigmoid(cls_predictions) > 0.5).float()
            num_correct += (cls_predictions == l).sum().item()
            num_seen += l.numel()

            # Running accuracy over the slices evaluated so far.
            loop.set_postfix(acc=num_correct / num_seen)
    model.train()
    return num_correct / num_seen if num_seen else 0.0

val_accuracy = calculate_accuracy(validation_dataloader, model)

calculating model accuracy...


  0%|          | 0/397 [00:00<?, ?it/s]



In [10]:
def save_results_as_imgs(model, dataset, path="./saved_images", suffix=""):
    """Save thresholded predictions and ground-truth masks for the first batch.

    Writes ``{path}/prediction{suffix}.png`` and ``{path}/ground_truth{suffix}.png``
    as image grids via ``torchvision.utils.save_image``, creating ``path`` if
    needed. With the default empty ``suffix`` both files are overwritten on
    every call (pass e.g. an epoch number to keep a history).

    Relies on the notebook global ``device``; leaves the model in train mode.
    """
    torch.cuda.empty_cache()
    print("Saving the results as images...")
    os.makedirs(path, exist_ok=True)

    model.eval()
    first_batch = next(iter(dataset), None)
    if first_batch is not None:
        x, y, _labels = first_batch
        with torch.no_grad():
            segmentation_preds, _cls = model(x.to(device))
            segmentation_preds = (torch.sigmoid(segmentation_preds) > 0.5).float()
            torchvision.utils.save_image(segmentation_preds, f"{path}/prediction{suffix}.png")
            # save_image scales by 255 in place, which fails on integer tensors;
            # cast in case the mask arrives as an integer dtype.
            torchvision.utils.save_image(y.unsqueeze(1).float(), f"{path}/ground_truth{suffix}.png")
    model.train()

save_results_as_imgs(model, validation_dataloader)

Saving the results as images...


In [12]:
# Create the checkpoint directory up front: otherwise a missing folder makes
# the final torch.save fail only after every epoch has already been trained.
os.makedirs("./checkpoint", exist_ok=True)

for epoch in range(num_epochs):
    print(f"epoch {epoch + 1}/{num_epochs}")
    train(model, train_dataloader, optimizer, criterion_dice, criterion_bce)
    calculate_accuracy(validation_dataloader, model)
    save_results_as_imgs(model, validation_dataloader)
torch.save(model.state_dict(), "./checkpoint/model.state.pth")

training model...


  0%|          | 0/397 [00:00<?, ?it/s]

calculating model accuracy...


  0%|          | 0/397 [00:00<?, ?it/s]

Saving the results as images...
training model...


  0%|          | 0/397 [00:00<?, ?it/s]

calculating model accuracy...


  0%|          | 0/397 [00:00<?, ?it/s]

Saving the results as images...
training model...


  0%|          | 0/397 [00:00<?, ?it/s]

calculating model accuracy...


  0%|          | 0/397 [00:00<?, ?it/s]

Saving the results as images...
training model...


  0%|          | 0/397 [00:00<?, ?it/s]

calculating model accuracy...


  0%|          | 0/397 [00:00<?, ?it/s]

Saving the results as images...
training model...


  0%|          | 0/397 [00:00<?, ?it/s]

calculating model accuracy...


  0%|          | 0/397 [00:00<?, ?it/s]

Saving the results as images...
training model...


  0%|          | 0/397 [00:00<?, ?it/s]

calculating model accuracy...


  0%|          | 0/397 [00:00<?, ?it/s]

Saving the results as images...
training model...


  0%|          | 0/397 [00:00<?, ?it/s]

calculating model accuracy...


  0%|          | 0/397 [00:00<?, ?it/s]

Saving the results as images...
training model...


  0%|          | 0/397 [00:00<?, ?it/s]

calculating model accuracy...


  0%|          | 0/397 [00:00<?, ?it/s]

Saving the results as images...
training model...


  0%|          | 0/397 [00:00<?, ?it/s]

calculating model accuracy...


  0%|          | 0/397 [00:00<?, ?it/s]

Saving the results as images...
training model...


  0%|          | 0/397 [00:00<?, ?it/s]

calculating model accuracy...


  0%|          | 0/397 [00:00<?, ?it/s]

Saving the results as images...


In [None]:
# Re-save the current weights to the same path as the training cell,
# making sure the target directory exists first.
os.makedirs("./checkpoint", exist_ok=True)
torch.save(model.state_dict(), "./checkpoint/model.state.pth")