In [1]:
import torch
import torchvision
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import matplotlib.pyplot as plt
import datasets
from tqdm.notebook import tqdm

In [2]:
train_transforms = A.Compose(
        [
            A.Rotate(limit=35, p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Resize(224, 224),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2()
        ],
    )

validation_transforms = A.Compose([
            A.Resize(224, 224),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2()
    ])

In [7]:
train_dataset = datasets.tdsc_2d.TDSC_2D(path="./data/tdsc/slices", train=True, transforms=train_transforms)
validation_dataset = datasets.tdsc_2d.TDSC_2D(path="./data/tdsc/slices", train=False, transforms=validation_transforms)

In [8]:
x,y,l = train_dataset[0]
print(x.shape)
print(y.shape)
print(l)

torch.Size([3, 224, 224])
torch.Size([224, 224])
1.0


In [19]:
class VGG16(torch.nn.Module):
    
    def __init__(self):
        super(VGG16, self).__init__()
        self.base_model = torchvision.models.vgg16(pretrained=True)
        self.base_model.classifier = torch.nn.Identity()
        # in_features = self.base_model.classifier[-1].in_features
        # print(in_features)
        # model.classifier[-1] = torch.nn.Linear(in_features, 1)
        # if weights are available load the weights...
        
    def forward(self, x):
        return self.base_model(x)
    



model = VGG16()
x = torch.rand([1, 3, 224, 224], dtype=torch.float32)
print(x.shape)

p = model(x)
print(p.shape)



torch.Size([1, 3, 224, 224])
torch.Size([1, 25088])


In [None]:
mini_batch_size = 16
learning_rate = 1e-4
device = "cuda" if torch.cuda.is_available() else "cpu"
num_epochs = 10
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scaler = torch.cuda.amp.GradScaler()

In [11]:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=mini_batch_size, shuffle=True, num_workers=4, pin_memory=True)
validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=mini_batch_size, shuffle=True, num_workers=4, pin_memory=True)

In [12]:
def calculate_accuracy(dataset, model, device="cuda"):
    
    print("calculating model accuracy...")
    num_correct = 0
    
    model.eval()
    loop = tqdm(dataset)

    with torch.no_grad():
        for data in loop:
            x, m, y = data
            y = y.unsqueeze(1)
            m = m.unsqueeze(1)
            
            x = x-x*m*0.5
            x = x.to(device)
            y = y.to(device)
            
            # forward
            cls_predictions = model(x)
            cls_predictions = (torch.sigmoid(cls_predictions) > 0.5).float()
            num_correct += (cls_predictions == y).sum()

            loop.set_postfix( acc=(num_correct/len(train_dataset)).item())
    model.train()
    
calculate_accuracy(validation_dataloader, model)

calculating model accuracy...


  0%|          | 0/199 [00:00<?, ?it/s]

In [13]:
for epoch in range(num_epochs):
    loop = tqdm(train_dataloader)
    for data in loop:
        x, m, y = data
        y = y.unsqueeze(1)
        m = m.unsqueeze(1)
        
        x = x-x*m*0.5
        x = x.to(device)
        y = y.to(device)
        
        with torch.cuda.amp.autocast():
            predictions = model(x)
            predictions = torch.sigmoid(predictions)
            loss = criterion(predictions, y)
        
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()            
        
        loop.set_postfix(loss=loss.item())
    calculate_accuracy(validation_dataloader, model)

  0%|          | 0/199 [00:00<?, ?it/s]

calculating model accuracy...


  0%|          | 0/199 [00:00<?, ?it/s]

  0%|          | 0/199 [00:00<?, ?it/s]

calculating model accuracy...


  0%|          | 0/199 [00:00<?, ?it/s]

  0%|          | 0/199 [00:00<?, ?it/s]

calculating model accuracy...


  0%|          | 0/199 [00:00<?, ?it/s]

  0%|          | 0/199 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [9]:
torch.save(model.state_dict(), "./checkpoint/vgg16.state.pth")

## Volume Classification using GRU