In [None]:
import sys, os
from google.colab import drive

drive.mount('/content/drive', force_remount=True)

# Absolute path: the original relative 'drive/My Drive/...' only resolved from /content,
# so re-running this cell (cwd already inside the project) raised FileNotFoundError.
PROJECT_DIR = '/content/drive/My Drive/Radiomics Workshop'
if PROJECT_DIR not in sys.path:  # avoid piling up duplicate entries on re-run
  sys.path.append(PROJECT_DIR)
os.chdir(PROJECT_DIR)

Mounted at /content/drive


In [None]:
from mice_dataset import PatchMiceDatasetFromTensor
from torch.utils.data import DataLoader

bs = 10  # samples per batch, shared by both loaders

dataset_name = 'dataForNet_shuffle_by_slices_2d_600_30per'
dataset_dir = f'datasets/{dataset_name}'

# Both splits are read from the same directory; `is_train=False` presumably selects the
# held-out split (TODO confirm in mice_dataset). For quick debugging the train set was
# previously limited with `max_dataset_size=...` — verify that kwarg before relying on it.
train_dataset = PatchMiceDatasetFromTensor(dataset_dir)
test_dataset = PatchMiceDatasetFromTensor(dataset_dir, is_train=False)

# Shuffle only training batches; the test loader keeps a deterministic order.
dl_train = DataLoader(train_dataset, batch_size=bs, shuffle=True)
dl_test = DataLoader(test_dataset, batch_size=bs, shuffle=False)

loader_sizes = (len(dl_train), len(dl_test))
print(*loader_sizes)

In [None]:
import numpy as np
import tqdm
import torch
import itertools
from network import UNet
from network import dice_loss
import utils

# Prefer the GPU whenever one is visible to PyTorch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Checkpoints for this run live under results/models/segmentation/<model_name>.
model_name = 'best'
root_path = 'results/models/segmentation'
model_path = '/'.join((root_path, model_name))

In [None]:

import copy

use_location = False
location = None
save = True  # write a checkpoint to disk every time the test loss improves
num_epochs = 100
train_params = {
    'device': device,
    'n_channels': 1,
    'n_classes': 1,
    'criterion': dice_loss,
    'optimizer': torch.optim.Adam,
    'lr': 0.0002,
    'layers': [16, 32, 64, 128],
    # NOTE(review): consumed by UNet's LR scheduler; presumably the LR stays constant for the
    # first `static_epochs_num` epochs and then decays until `total_epochs_num` — confirm in network.py.
    "scheduler_params": {
        "total_epochs_num": num_epochs,
        "static_epochs_num": num_epochs / 5,
    },
}

model = UNet(**train_params).to(device)
utils.init_model_weights(model)
print(model)


def _run_train_epoch(model, loader, device):
  """Run one pass of `model.train_batch` over `loader`; return the mean batch loss."""
  batch_losses = []
  for batch in tqdm.tqdm(loader, total=len(loader), file=sys.stdout):
    image = batch['image'].to(device, dtype=torch.float32)
    mask = batch['mask'].to(device, dtype=torch.float32)
    loss = model.train_batch(image, mask)
    batch_losses.append(loss.item())
  return np.mean(batch_losses)


def _run_test_epoch(model, loader, device):
  """Evaluate `loader` with `model.test_batch`.

  Returns:
    (mean batch loss, mean over batches of the second value returned by test_batch).
  """
  batch_losses = []
  metric_sum = 0
  for batch in tqdm.tqdm(loader, total=len(loader), file=sys.stdout):
    image = batch['image'].to(device, dtype=torch.float32)
    mask = batch['mask'].to(device, dtype=torch.float32)
    # NOTE(review): the second value was named both "mistakes_num" and "iou" in the
    # original cell — confirm in network.py what test_batch actually returns.
    batch_loss, batch_metric = model.test_batch(image, mask)
    # float() accepts Python/NumPy scalars and one-element tensors alike (mirrors .item() in training).
    batch_losses.append(float(batch_loss))
    metric_sum += batch_metric
  # Averaged per batch, so a smaller final batch carries slightly more weight.
  return np.mean(batch_losses), metric_sum / len(loader)


train_loss_per_epoch, test_loss_per_epoch, accuracy_per_epoch = [], [], []
min_test_loss_epoch = float("inf")  # lowest test loss seen so far (a loss value, not an epoch index)
best_state_dict = None  # in-memory copy of the weights with the lowest test loss

for epoch in range(num_epochs):
  print(f'--- EPOCH {epoch + 1}/{num_epochs} ---')

  train_loss_per_epoch.append(_run_train_epoch(model, dl_train, device))
  model.update_learning_rate()  # step the LR schedule once per epoch

  cur_epoch_test_loss, total_iou = _run_test_epoch(model, dl_test, device)
  test_loss_per_epoch.append(cur_epoch_test_loss)
  accuracy_per_epoch.append(total_iou)

  print("Epoch", epoch + 1)  # 1-based, consistent with the header and checkpoint names
  print("Train loss", train_loss_per_epoch[-1])
  print("Test loss", test_loss_per_epoch[-1])
  print("Test Accuracy", total_iou)

  if cur_epoch_test_loss < min_test_loss_epoch:
    min_test_loss_epoch = cur_epoch_test_loss
    # state_dict() returns references to the live parameter tensors; deep-copy it so that
    # later optimisation steps do not silently overwrite the snapshot of the best epoch.
    best_state_dict = copy.deepcopy(model.state_dict())

    if save:
      os.makedirs(model_path, exist_ok=True)
      print(f'**** Saving in epoch {epoch + 1} *****')
      saved_state = dict(test_losses=test_loss_per_epoch,
                         train_losses=train_loss_per_epoch,
                         test_accuracies=accuracy_per_epoch,
                         model_state=best_state_dict,
                         dataset_name=dataset_name,
                         train_params=train_params)
      torch.save(saved_state, f"{model_path}/model_{epoch + 1}_epochs")

# Load Model
Load the best saved checkpoint and display the input image, ground-truth mask, and model output for test patches.

In [None]:
from utils import tensors_as_images

epoch_name = "model_67_epochs"  # checkpoint file written by the training cell; set to your run's best epoch
num_patches_to_show = 50

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# The checkpoint pickles `train_params` (which holds the loss function and optimizer class),
# so it cannot be read with weights_only=True, the default since PyTorch 2.6 (the keyword
# itself needs PyTorch >= 1.13). Only load checkpoints you produced yourself: full
# unpickling can execute arbitrary code.
params = torch.load(f"{model_path}/{epoch_name}", map_location=device, weights_only=False)
# Rebuild on the current device even if the checkpoint was written on another (e.g. GPU -> CPU).
train_params = {**params['train_params'], 'device': device}
model_state = params['model_state']
model = UNet(**train_params).to(device)
model.load_state_dict(model_state)
# Inference mode (affects layers such as dropout / batch-norm, if the UNet has any).
model.eval()

# Alternative without reading from disk: model.load_state_dict(best_state_dict)
count = 0
with torch.no_grad():
  for batch in dl_test:
    if count == num_patches_to_show:
      break
    # Same dtype as during training so the network sees identically typed inputs.
    image = batch['image'].to(device, dtype=torch.float32)
    mask = batch['mask'].to(device, dtype=torch.float32)
    # NOTE(review): output is displayed exactly as the network returns it; if UNet emits
    # logits, apply torch.sigmoid (and threshold at 0.5 for a binary mask) — confirm in network.py.
    output = model(image).cpu()
    # Show every sample with a non-zero label, not just the first sample of each batch.
    for j, label in enumerate(batch['label']):
      if count == num_patches_to_show:
        break
      if label == 0:
        continue
      tensors_as_images([image[j].cpu(), mask[j].cpu(), output[j]])
      count += 1