In [1]:
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
from torchsummary import summary

from oct_dataset import OCTDataset

In [2]:
# Use the GPU when available, otherwise fall back to the CPU so that `device` is always
# defined (every later cell reads it).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [4]:
# NOTE(review): kills a TensorBoard process by a PID hard-coded from an earlier session
# (the captured output shows that process no longer existed). Leftover housekeeping that
# is not needed on a fresh run; safe to remove.
!kill 32120

kill: 32120: No such process


In [3]:
# Start (or reuse) an in-notebook TensorBoard on port 6006, reading the `logs` directory
# that the SummaryWriter cell further down writes its run_<n> folders into.
%load_ext tensorboard
%tensorboard --logdir logs --port 6006

Reusing TensorBoard on port 6006 (pid 32120), started 16:49:37 ago. (Use '!kill 32120' to kill it.)

In [None]:
# Preprocessing pipeline applied by OCTDataset. The torchvision module is imported under an
# alias so this cell is idempotent: assigning `transforms.Compose(...)` back onto the module's
# own name made a second run fail (the name then refers to a Compose object, not the module).
# The result stays bound to `transforms` because the dataset cell passes `transform=transforms`.
from torchvision import transforms as tv_transforms

transforms = tv_transforms.Compose([tv_transforms.ToTensor()])

In [None]:
# Hyperparameters shared by the cells below.
hparams = {
    "batch_size": 16,
    "learning_rate": 1e-3,
    "input_size": 1 * 1024 * 512,  # flattened C*H*W of one input image (C=1, 1024x512)
    "in_channels": 1,
    "out_channels": 5,  # second argument to UNet; presumably the number of segmentation classes — confirm
    "device": device,
    "epochs": 50,  # NOTE(review): the full training call passes epochs=40 explicitly, so this is unused
    "weight_decay": 1e-6,  # NOTE(review): not passed to the optimizer (left commented out there)
    # StepLR decay factor that train_model reads via hparams.get('gamma', 0.8); spelled out here
    # with the same value so the learning-rate schedule is visible next to the other settings.
    "gamma": 0.8,
}

In [None]:
train_dataset = OCTDataset(root_dir='data/train_data', transform=transforms)
test_dataset = OCTDataset(root_dir='data/test_data', transform=transforms)
val_dataset = OCTDataset(root_dir='data/val_data', transform=transforms)

# Only the training loader is shuffled. Evaluation loaders keep a fixed order so the per-batch
# validation losses logged to TensorBoard (indexed by batch position) line up across epochs and runs.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=hparams['batch_size'], shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=hparams['batch_size'], shuffle=False)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=hparams['batch_size'], shuffle=False)

In [None]:
# Peek at a single training batch to check the image and mask tensor shapes.
first_batch = next(iter(train_dataloader), None)
if first_batch is not None:
    images, masks = first_batch
    print(images.shape)
    print(masks.shape)

In [None]:
def visualize_data(dataloader):
    """Plot the first sample of the first batch: the image, and the image with its mask overlaid.

    Args:
        dataloader: iterable yielding (images, masks) batches of tensors. Each per-sample tensor is
            assumed to be channel-first (C, H, W), as the permute(1, 2, 0) below requires.

    Does nothing if the dataloader yields no batches.
    """
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        return
    images, masks = first_batch

    # (C, H, W) -> (H, W, C) for matplotlib, then drop a singleton channel so imshow receives a
    # 2-D array (cmap is only applied to 2-D data).
    image = np.squeeze(images[0].permute(1, 2, 0).cpu().numpy())
    mask = np.squeeze(masks[0].permute(1, 2, 0).cpu().numpy())

    # Two separate axes: the left shows the raw image, the right overlays the mask on the image.
    fig, (image_ax, overlay_ax) = plt.subplots(1, 2, figsize=(10, 5))

    # Plot the image
    image_ax.imshow(image, cmap='gray')
    image_ax.set_title('Image')
    image_ax.axis('off')

    # Plot the segmentation mask on top of the image
    overlay_ax.imshow(image, cmap='gray')
    overlay_ax.imshow(mask, cmap='viridis', alpha=0.5)  # Adjust cmap based on your segmentation task
    overlay_ax.set_title('Segmentation Mask')
    overlay_ax.axis('off')

    plt.show()

# Preview one sample from each split, in train / test / val order.
for split_loader in (train_dataloader, test_dataloader, val_dataloader):
    visualize_data(split_loader)

In [None]:
from seg_model import UNet

model = UNet(hparams["in_channels"], hparams['out_channels'])
model.to(device)

# Layer-by-layer summary for one 1024x512 input with `in_channels` channels (same volume as
# hparams["input_size"]). torchsummary's `device` argument defaults to "cuda", so pass the real
# device type to make this cell work on CPU-only machines too. batch_size mirrors hparams so the
# printed shapes and size estimate match the actual training batches.
summary(model, input_size=(hparams["in_channels"], 1024, 512), batch_size=hparams['batch_size'], device=device.type)

In [None]:
def train(model, train_dataloader, val_dataloader, criterion, optimizer, device, num_epochs):
    """Minimal train/validate loop that prints the mean train and val loss after every epoch.

    NOTE(review): the notebook's training cells call `train_model`; this simpler loop is not used.

    Args:
        model: torch.nn.Module; moved to `device` here.
        train_dataloader: iterable of (images, masks) batches supporting len().
        val_dataloader: iterable of (images, masks) batches supporting len().
        criterion: callable (outputs, masks) -> scalar loss tensor.
        optimizer: optimizer over `model`'s parameters.
        device: torch device (or device string) that the model and batches are moved to.
        num_epochs: number of passes over `train_dataloader`.
    """
    model.to(device)

    for epoch in range(num_epochs):
        # Re-enter training mode every epoch: the validation pass below switches the model to
        # eval(), and without this it would stay in eval mode (Dropout/BatchNorm) for all later epochs.
        model.train()
        train_loss = 0.0

        # Training step
        for images, masks in train_dataloader:
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        train_loss /= max(len(train_dataloader), 1)

        # Validation step
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, masks in val_dataloader:
                images = images.to(device)
                masks = masks.to(device)

                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item()

        val_loss /= max(len(val_dataloader), 1)

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    print("Training complete!")


In [None]:
def save_model(model, path=r'D:\Desktop\demir\oct_segmentation\saved_models', name='default'):
    """Save `model`'s state_dict to a fresh `<path>/version<N>/final_<name>.pth`.

    N is the smallest positive integer for which `version<N>` does not yet exist under `path`, so
    earlier saves are never overwritten. `path` and any missing parents are created as needed.

    Args:
        model: torch.nn.Module whose state_dict is written.
        path: root directory for the versioned folders. NOTE(review): the default is an absolute,
            machine-specific Windows path; pass your own directory on other machines.
        name: suffix used in the checkpoint filename.

    Returns:
        The path of the written .pth file.
    """
    version_number = 1
    while True:
        version_folder = os.path.join(path, f'version{version_number}')
        try:
            # makedirs without exist_ok fails if the folder exists, making check-and-create a
            # single step instead of a separate exists() test followed by creation.
            os.makedirs(version_folder)
            break
        except FileExistsError:
            version_number += 1

    # Save the checkpoint inside the newly created version folder.
    save_path = os.path.join(version_folder, f'final_{name}.pth')
    torch.save(model.state_dict(), save_path)
    return save_path

In [None]:
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
 

def create_tqdm_bar(iterable, desc):
    """Wrap `iterable` in a 150-column tqdm bar that yields (index, item) pairs.

    `iterable` must support len() so the bar can display a total.
    """
    numbered_items = enumerate(iterable)
    step_count = len(iterable)
    return tqdm(numbered_items, total=step_count, ncols=150, desc=desc)


def _prepare_labels(labels):
    """Turn a mask batch into int64 class-index targets.

    Removes every singleton dimension except the batch dimension, e.g. (B, 1, H, W) -> (B, H, W).
    A bare .squeeze() would also remove the batch dimension when a batch holds a single sample
    (batch_size=1 or a 1-sample final batch), which breaks the loss call.
    """
    for dim in range(labels.dim() - 1, 0, -1):
        if labels.size(dim) == 1:
            labels = labels.squeeze(dim)
    return labels.long()


def _train_one_epoch(model, loader, loss_func, optimizer, scheduler, tb_logger, device, epoch, epochs, name, loss_window):
    """Run one optimisation pass over `loader`, stepping `scheduler` per batch and logging each loss."""
    model.train()
    recent_losses = []
    training_loop = create_tqdm_bar(loader, desc=f'Training Epoch [{epoch + 1}/{epochs}]')
    for train_iteration, (images, labels) in training_loop:
        optimizer.zero_grad()
        images, labels = images.to(device), _prepare_labels(labels).to(device)

        pred = model(images)
        loss = loss_func(pred, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

        loss_value = loss.item()
        recent_losses.append(loss_value)
        if loss_window:
            # Moving window for the progress-bar average only; TensorBoard gets every value.
            recent_losses = recent_losses[-loss_window:]

        training_loop.set_postfix(curr_train_loss="{:.8f}".format(np.mean(recent_losses)),
                                  lr="{:.8f}".format(optimizer.param_groups[0]['lr']))
        tb_logger.add_scalar(f'classifier_{name}/train_loss', loss_value, epoch * len(loader) + train_iteration)


def _validate_one_epoch(model, loader, loss_func, tb_logger, device, epoch, epochs, name):
    """Run one no-grad pass over `loader`, updating the progress bar and logging each batch loss."""
    model.eval()
    val_losses = []
    val_loop = create_tqdm_bar(loader, desc=f'Validation Epoch [{epoch + 1}/{epochs}]')
    with torch.no_grad():
        for val_iteration, (images, labels) in val_loop:
            images, labels = images.to(device), _prepare_labels(labels).to(device)

            loss_value = loss_func(model(images), labels).item()
            val_losses.append(loss_value)

            val_loop.set_postfix(val_loss="{:.8f}".format(np.mean(val_losses)))
            tb_logger.add_scalar(f'classifier_{name}/val_loss', loss_value, epoch * len(loader) + val_iteration)


def train_model(model, train_loader, val_loader, loss_func, tb_logger, optimizer, epochs=10, name="default", save_path=r'D:\Desktop\demir\oct_segmentation\saved_models', gamma=None, device=None, detect_anomaly=False, checkpoint_every=10):
    """
    Train the classifier for a number of epochs, validating and logging to TensorBoard each epoch.

    Args:
        model: torch.nn.Module to train (updated in place).
        train_loader: iterable of (images, labels) batches supporting len().
        val_loader: iterable of (images, labels) batches supporting len().
        loss_func: callable (pred, labels) -> scalar loss tensor (CrossEntropyLoss in this notebook).
            Labels have singleton non-batch dims removed and are cast to int64 first.
        tb_logger: object with add_scalar (a SummaryWriter here).
        optimizer: optimizer over `model`'s parameters. Its learning rate is decayed in place by
            the StepLR scheduler created here.
        epochs: number of passes over `train_loader`.
        name: tag used in TensorBoard scalar names and checkpoint filenames.
        save_path: directory for intermediate checkpoints and the final `save_model` call; created
            if missing before an intermediate checkpoint is written.
        gamma: StepLR decay factor; None falls back to the notebook-global hparams.get('gamma', 0.8).
        device: where batches are moved; None uses the device of the model's first parameter.
        detect_anomaly: enable autograd anomaly detection for this call only (slow; debugging aid).
        checkpoint_every: write `<name>_epoch_<k>.pth` after every k-th completed epoch except the
            last (the final weights are saved by `save_model`); 0 disables intermediate checkpoints.

    Returns:
        The trained model (the same object that was passed in).
    """
    if device is None:
        # Follow the model's own placement instead of relying on a notebook-global `device`.
        device = next(model.parameters()).device
    if gamma is None:
        gamma = hparams.get('gamma', 0.8)

    # The scheduler is stepped once per batch, so step_size counts iterations: the LR decays five
    # times over the whole run. max(1, ...) prevents a modulo-by-zero inside StepLR for tiny runs.
    step_size = max(1, int(epochs * len(train_loader) / 5))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    # Progress-bar loss window of ~10% of an epoch; 0 (fewer than 10 batches) keeps every value.
    loss_window = len(train_loader) // 10

    # Used as a context manager so anomaly mode is restored on exit instead of staying on globally.
    with torch.autograd.set_detect_anomaly(detect_anomaly):
        for epoch in range(epochs):
            _train_one_epoch(model, train_loader, loss_func, optimizer, scheduler, tb_logger,
                             device, epoch, epochs, name, loss_window)
            _validate_one_epoch(model, val_loader, loss_func, tb_logger, device, epoch, epochs, name)

            completed_epochs = epoch + 1
            if checkpoint_every and completed_epochs % checkpoint_every == 0 and completed_epochs < epochs:
                os.makedirs(save_path, exist_ok=True)
                save_model_path = os.path.join(save_path, f'{name}_epoch_{completed_epochs}.pth')
                torch.save(model.state_dict(), save_model_path)

    save_model(model, save_path, name)
    return model
        

In [None]:
def test(model, dataloader):
    test_scores = []
    model.eval()
    model = model.to(device)
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model.forward(inputs)
        _, preds = torch.max(outputs, 1)
        targets_mask = (targets >= 0).cpu()
        test_scores.append(np.mean((preds.cpu() == targets.cpu())[targets_mask].numpy()))

    return np.mean(test_scores)

In [None]:
# Loss criterion and optimizer used by the training cells below.
loss = torch.nn.CrossEntropyLoss()
# NOTE(review): weight decay from hparams['weight_decay'] is not applied (it was commented out here).
optimizer = torch.optim.Adam(model.parameters(), lr=hparams['learning_rate'])

# Each run logs into its own logs/run_<n> directory. Start from (number of existing entries + 1)
# and skip past any name already taken: if an older run folder was deleted, the count alone can
# point at a directory that still exists, and the new run's scalars would be mixed into it.
path = 'logs'
num_of_runs = len(os.listdir(path)) if os.path.exists(path) else 0
run_number = num_of_runs + 1
while os.path.exists(os.path.join(path, f'run_{run_number}')):
    run_number += 1
path = os.path.join(path, f'run_{run_number}')
tb_logger = SummaryWriter(path)

In [None]:
# sanity check make small dataset see if network can overfit
index_list = [100, 32, 326]
small_trainset = torch.utils.data.Subset(train_dataset, index_list)
small_trainloader = torch.utils.data.DataLoader(small_trainset, batch_size=3, shuffle=True)
small_valset = torch.utils.data.Subset(val_dataset, index_list)
small_valloader = torch.utils.data.DataLoader(small_valset, batch_size=3, shuffle=True)

trained_test_model = train_model(model=model, train_loader=small_trainloader, val_loader=small_valloader, loss_func=loss, tb_logger=tb_logger, optimizer=optimizer, epochs=10, name='sanity check')

In [None]:
# Full training run on the complete train/val splits.
# NOTE(review): epochs=40 here while hparams["epochs"] is 50 — confirm which is intended.
trained_model = train_model(
    model=model,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    loss_func=loss,
    tb_logger=tb_logger,
    optimizer=optimizer,
    epochs=40,
    name='oct_seg_v1',
)