In [20]:
import argparse
import logging
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from monai.networks.nets import BasicUNet as BU

import import_ipynb
%run DiceLoss.ipynb import DiceLoss
import DosePredictionDataset


import numpy as np

In [21]:
def train_net(net,
              device,
              input_dir,
              dose_dir,
              epochs: int = 5,
              batch_size: int = 1,
              drop_out: int=0.5,
              learning_rate: float = 0.001,
              val_percent: float = 0.1,
              ):

    # create dataloaders for training and validation
    dataset = DosePredictionDataset.DosePrdictionDataset(input_dir, dose_dir)
    val_num = int(len(dataset) * val_percent)
    train_num = len(dataset) - val_num
    train_set, val_set = random_split(dataset, [train_num, val_num], generator=torch.Generator().manual_seed(42))    
    loader_args = dict(batch_size=batch_size, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)


    # Set up optimizer/loss/learning rate/scheduler
    optimizer = optim.RMSprop(net.parameters(), 
                              lr=learning_rate, 
                              weight_decay=1e-8, 
                              momentum=0.9)

    # Training begins
    for epoch in range(epochs):
        epoch_loss_train = 0
        epoch_loss_val = 0
        with tqdm(total=train_num, desc=f'Epoch {epoch + 1}/{epochs}', unit="images") as progressBar:
            # train
            net.train()
            epoch_loss = 0
            for images, labels in train_loader:
                images = images.to(device=device, dtype=torch.float32)
                labels = labels.to(device=device, dtype=torch.float32)
                image_pred = net(images)
                loss = DiceLoss.DiceLoss().forward(image_pred,labels)
                epoch_loss_train += loss
                optimizer.zero_grad()
                loss.backwards()
                optimizer.step()
                progressBar.update(images.shape[0])
                progressBar.set_postfix(**{'loss (batch)': loss.item()})
            self.train_dice_loss.append(epoch_loss_train)

            # validation
            with torch.no_grad():
                for images, labels in val_loader:
                    images = images.to(device=device, dtype=torch.float32)
                    labels = labels.to(device=device, dtype=torch.float32)
                    image_pred = net(images)
                    loss = DiceLoss().forward(image_pred,labels)
                    epoch_loss_val += loss
                self.val_dice_loss.append(epoch_loss_val) 
                

In [71]:
input_dir = "/Users/wangyangwu/Documents/Maastro/NeuralNets/BasicUNet/sample/input/"
label_dir = "/Users/wangyangwu/Documents/Maastro/NeuralNets/BasicUNet/sample/label/"

pathlist_input = Path(input_dir).rglob('*.csv')
pathlist_label = Path(label_dir).rglob('*.csv')
input_paths = [str(path) for path in pathlist_input]
label_paths = [str(path) for path in pathlist_label]

In [72]:
net = BU(spatial_dims=2,in_channels=1, out_channels=1, features=(64, 128, 256, 512, 1024, 1))
train_net(net, "cpu", input_paths, label_paths)

BasicUNet features: (64, 128, 256, 512, 1024, 1).


RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 1, 3, 3], but got 3-dimensional input of size [1, 512, 512] instead

In [6]:
import wandb
wandb.login()

True