In [None]:
import torch
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np

class Trainer(object):
    """Run training and validation epochs for an auto-encoder style model.

    The model is called as ``model(batch)`` and must return a sequence that is
    unpacked into ``loss_function``; element 0 is treated as the
    reconstruction and element 1 as the original input when images are saved.
    ``loss_function`` must return a mapping with a ``'loss'`` entry.
    """

    def __init__(self, model,
                 optimizer, loss_function,
                 loader_train, loader_val,
                 dtype, device, **in_params):
        """
        :param model: PyTorch model of the neural network

        :param optimizer: PyTorch optimizer

        :param loss_function: callable ``loss_function(*model_output, M_N=...)``
            returning a mapping with a ``'loss'`` tensor

        :param loader_train: iterable of training batches exposing ``.dataset``

        :param loader_val: iterable of validation batches exposing ``.dataset``

        :param dtype: dtype every batch is cast to before the forward pass

        :param device: device the model and every batch are moved to

        :param in_params: must contain ``print_every`` (how often, in batches,
            the training loss is printed), ``batch_size``, ``input_size``
            (side length used when reshaping images) and ``path`` (output
            directory for saved images); any other keys are ignored
        """
        # Create attributes:
        self.device = device
        self.model = model.to(device=self.device)  # move the model parameters to CPU/GPU
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.loader_train = loader_train
        self.loader_val = loader_val
        self.print_every = in_params["print_every"]
        self.dtype = dtype
        self.batch_size = in_params["batch_size"]
        self.input_size = in_params["input_size"]
        self.path = in_params["path"]
        # Per-epoch average losses, appended to by test_model / train_model.
        self.collect_test_loss = []
        self.collect_train_loss = []

    def _loss_weight(self, loader):
        """Return the ``M_N`` value handed to the loss function for ``loader``."""
        # NOTE(review): M_N is presumably the weight of the KL term - confirm
        # against the model's loss_function.
        return 1e-7 * self.batch_size / len(loader)

    def train_model(self, epoch):
        """Train for one epoch and record the average loss.

        - epoch: An integer giving the epoch (only used for printing)

        Returns the summed batch losses divided by the training dataset size;
        the same value is appended to ``self.collect_train_loss``.
        """
        train_loss = 0
        self.model.train()  # put model to training mode
        for t, batch in enumerate(self.loader_train):
            batch = batch.to(device=self.device, dtype=self.dtype)  # move to device, e.g. GPU

            # do a step in training
            args = self.model(batch)
            loss = self.loss_function(
                *args, **{'M_N': self._loss_weight(self.loader_train)})['loss']
            self.optimizer.zero_grad()
            loss.backward()
            # Clip the global gradient norm to 1 before the update.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
            self.optimizer.step()

            batch_loss = loss.item()
            train_loss += batch_loss  # accumulate for average loss

            # print loss
            if t % self.print_every == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, t * len(args[1]), len(self.loader_train.dataset),
                    100. * t / len(self.loader_train),
                    batch_loss / len(args[1])))

        # print average loss
        average_loss = train_loss / len(self.loader_train.dataset)
        self.collect_train_loss.append(average_loss)
        print('====> Epoch: {} Average loss: {:.6f}'.format(epoch, average_loss))
        return average_loss

    def _save_reconstruction(self, args, epoch, n=1):
        """Save the first ``n`` inputs next to their reconstructions as a PNG.

        The channel axis is split into 3-channel images so that stacked views
        are written as separate tiles; the number of views is inferred.
        """
        size = self.input_size
        original_img = args[1][:n].reshape(-1, 3, size, size)
        reconstructed_img = args[0].reshape(
            -1, self.model.out_channels, size, size)[:n]
        reconstructed_img = reconstructed_img.reshape(-1, 3, size, size)
        # Concatenate along the width so each tile shows input | reconstruction.
        comparison = torch.cat([original_img, reconstructed_img], -1)
        save_image(comparison.cpu(),
                   self.path + '/reconstruction_' + str(epoch) + '.png', nrow=n)

    def test_model(self, epoch):
        """Evaluate on the validation loader and record the average loss.

        - epoch: An integer giving the epoch; on every 10th epoch a
          reconstruction image of the first batch is written to ``self.path``

        Returns the summed batch losses divided by the validation dataset
        size; the same value is appended to ``self.collect_test_loss``.
        """
        self.model.eval()  # Put model to evaluation mode
        test_loss = 0.

        with torch.no_grad():
            # During validation, we accumulate these values across the whole dataset and then average at the end:
            for i, batch in enumerate(self.loader_val):
                batch = batch.to(device=self.device, dtype=self.dtype)  # move to device, e.g. GPU

                # compute loss and accumulate
                args = self.model(batch)
                test_loss += self.loss_function(
                    *args, **{'M_N': self._loss_weight(self.loader_val)})['loss'].item()
                if i == 0 and epoch % 10 == 0:
                    self._save_reconstruction(args, epoch)

        # print average loss
        test_loss /= len(self.loader_val.dataset)
        self.collect_test_loss.append(test_loss)
        print('====> Test set loss: {:.6f}'.format(test_loss))
        return test_loss

    @staticmethod
    def _plot_loss(values, title, path):
        """Plot one loss curve on its own log-scale figure, save and show it."""
        # A dedicated figure keeps the curves separate even on backends where
        # plt.show() does not reset the current figure.
        fig, ax = plt.subplots()
        ax.set_yscale('log')
        ax.plot(np.arange(len(values)), values)
        ax.grid()
        ax.set_title(title)
        fig.savefig(path)
        plt.show()
        plt.close(fig)

    def train_and_test(self, epochs, path1, path2):
        """Run ``epochs`` train/validation epochs, then plot the loss curves.

        - epochs: number of epochs to run (epochs are numbered from 1)
        - path1: file the training-loss plot is saved to
        - path2: file the validation-loss plot is saved to

        Every 10th epoch, 64 samples are drawn from ``self.model.sample`` and
        their first three channels are saved as an image in ``self.path``.
        """
        for e in range(1, epochs + 1):
            self.train_model(e)
            self.test_model(e)
            if e % 10 == 0:
                with torch.no_grad():
                    # Use the trainer's own device rather than a module global.
                    sample = self.model.sample(64, self.device)
                    save_image(sample[:, :3, :, :], self.path + '/sample_' + str(e) + '.png')

        # Print and save loss plots
        trloss = self.collect_train_loss
        teloss = self.collect_test_loss
        if not trloss or not teloss:
            # No epoch has been recorded, so there is nothing to summarise.
            return
        print(np.min(trloss), np.min(teloss))

        self._plot_loss(trloss, 'average train loss', path1)
        self._plot_loss(teloss, 'test loss', path2)

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from models.data_loading import LightFieldDataset
from models.VAE_changed import VAE
#from models.Trainer import Trainer
from models.transformations import * 
import os

# Output directory for the checkpoint, saved images and loss plots.
res_folder = 'results'
os.makedirs(res_folder, exist_ok=True)
model_path = os.path.join(res_folder, 'model.pth')

# Run configuration; the whole dict is forwarded to the Trainer as keyword
# arguments further down.
in_params = dict(
    batch_size=40,
    epochs=1500,
    no_cuda=False,
    seed=1,
    print_every=20,
    input_size=64,
    path=res_folder,
    in_channels=9 * 3,  # presumably 9 stacked views x 3 colour channels - confirm
)
in_params["cuda"] = (not in_params["no_cuda"]) and torch.cuda.is_available()
torch.manual_seed(in_params["seed"])

if in_params["cuda"]:
    device = torch.device("cuda")
    # Extra DataLoader options that are only passed when running on the GPU.
    kwargs = {'num_workers': 4, 'pin_memory': True}
else:
    device = torch.device("cpu")
    kwargs = {}

# Training inputs get a random crop followed by one randomly chosen
# augmentation; test inputs only get the random crop.
transformation_list = [Brightness(1), Color_jitter(1), Contrast(1), Noise(1)]
transformations = transforms.Compose([
    RandomCrop(in_params['input_size']),
    transforms.RandomChoice(transformation_list),
])

# Datasets
train_set = LightFieldDataset(
    sort=['training', 'stratified', 'additional'],
    data_kind='stack',
    root_dir='data',
    transform=transformations,
)
test_set = LightFieldDataset(
    sort=['test'],
    data_kind='stack',
    root_dir='data',
    transform=RandomCrop(in_params['input_size']),
)

# Both loaders share the same batching options: shuffled, incomplete last
# batch dropped.
loader_options = dict(
    batch_size=in_params["batch_size"],
    shuffle=True,
    drop_last=True,
    **kwargs
)
train_loader = DataLoader(train_set, **loader_options)
test_loader = DataLoader(test_set, **loader_options)

# Model
model = VAE(
    in_channels=in_params['in_channels'],
    in_size=in_params["input_size"],
    hidden_dims=[64, 128, 256],
)

# Optimizer
params = model.parameters()
learning_rate = 1e-4
optimizer = torch.optim.AdamW(params, lr=learning_rate)

# Trainer; the loss is the model's own loss_function.
trainer = Trainer(
    model, optimizer, model.loss_function,
    train_loader, test_loader, torch.float32, device,
    **in_params
)

# Train, then write the train/test loss plots into the results folder.
path1 = os.path.join(res_folder, 'loss_train')
path2 = os.path.join(res_folder, 'loss_test')
trainer.train_and_test(in_params["epochs"], path1, path2)

# Persist model and optimizer state together with the number of epochs run.
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epoch': in_params["epochs"],
}
torch.save(checkpoint, model_path)