In [None]:
import os
import random
import copy
import torch
import torch.nn as nn
import torchvision
from torch.nn import init
import torchvision.models as models
from torchvision import datasets
import torch.utils.data as data
from torchvision.transforms import transforms 
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from optimizer import LARS
from PIL import Image
import torchvision as tv
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class Config(object):
    """Hyper-parameters and output locations for the CIFAR-10-C autoencoder run.

    Creating an instance also creates the output directory tree under
    ``./CHECKPOINT/<name>`` (existing directories are left alone).
    """

    def __init__(self):
        self.name = 'CIFAR_AUTO'
        self.dataset_name = 'CIFAR-10-C'
        self.dataroot = '../DATASETS/' + self.dataset_name

        # output locations
        self.save = './CHECKPOINT/' + self.name
        self.model_path = self.save + '/models'
        self.decode_path = self.save + '/decoded_results'
        self.val_path = self.save + '/val_results'
        self.test_path = self.save + '/test_results'
        self.runs = self.save + '/runs'

        self.seed = 1                        # applied to random/numpy/torch right after `opt` is created below
        self.input_nc = 3                    # input channel number
        self.output_nc = 3                   # output channel number

        # dataset
        self.shuffle_dataset = True
        self.batch_size = 256
        self.input_shape = '32,32,3'         # 'H,W,C'; only the first field is read by the data cell
        self.num_workers = 16

        # optimization
        self.base_learning_rate = 0.02
        self.learning_rate = self.base_learning_rate / self.batch_size
        # learning_rate * batch_size == base_learning_rate, so max_lr == 0.02
        self.max_lr = self.learning_rate * self.batch_size
        self.learning_rate_min = 8e-4
        self.Base_momentum = 0.99
        self.weight_decay = 5e-4
        self.epochs = 1000
        self.warmup_epochs = 30

        # makedirs creates missing parents; exist_ok makes re-running the cell harmless
        for directory in (self.save, self.model_path, self.decode_path,
                          self.val_path, self.test_path, self.runs):
            os.makedirs(directory, exist_ok=True)

opt = Config()

# Seed every RNG the notebook uses (python, numpy, torch on all devices) so runs are reproducible.
random.seed(opt.seed)
np.random.seed(opt.seed)
torch.manual_seed(opt.seed)

In [None]:
class MYDS(data.Dataset):
    """In-memory, unlabeled image dataset built from one or more ``.npy`` files.

    Each file is concatenated along its first axis, so every file must hold a
    stack of images with matching trailing dimensions. Items are converted to
    RGB ``PIL.Image`` objects and passed through ``transform``.

    Args:
        noise: iterable of file names (e.g. ``'frost.npy'``) relative to ``root``.
        transform: callable applied to each ``PIL.Image`` in ``__getitem__``.
        root: directory containing the files; defaults to ``opt.dataroot``.

    Raises:
        ValueError: if ``noise`` names no files.
    """

    def __init__(self, noise, transform, root=None):
        path = opt.dataroot if root is None else root
        file_names = list(noise)
        if not file_names:
            raise ValueError('noise must name at least one .npy file')

        self.transform = transform
        self.ys = None  # no labels: __getitem__ returns images only
        self.xs = []    # per-file arrays, in load order
        for name in file_names:
            x_path = os.path.join(path, name)
            print(f'loaded: {x_path}')
            self.xs.append(np.load(x_path))

        # Concatenate the list directly: wrapping it in np.array() first made an
        # extra full copy and breaks when files hold different image counts.
        self.xss = np.concatenate(self.xs)

    def __len__(self):
        return len(self.xss)

    def __getitem__(self, idx):
        img = self.xss[idx]
        image = Image.fromarray(img.astype('uint8'), 'RGB')
        return self.transform(image)

In [None]:
# Corruption files to load from opt.dataroot: a small subset of the actual dataset.
noise = ['frost.npy', 'elastic_transform.npy', 'impulse_noise.npy', 'shot_noise.npy', 'zoom_blur.npy']

# Target side length = first field of the 'H,W,C' config string
# (parsed explicitly instead of eval()-ing configuration text).
image_size = int(opt.input_shape.split(',')[0])

# NOTE(review): the per-channel mean/std below look like output of
# online_mean_and_sd on this data; re-check them if `noise` changes.
data_transforms = transforms.Compose([
    transforms.Resize(size=image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=torch.tensor([0.5204, 0.5169, 0.4858]),
                         std=torch.tensor([0.2487, 0.2472, 0.2654])),
])

ds = MYDS(noise, data_transforms)
# drop_last=True: every batch holds exactly opt.batch_size images.
dl = data.DataLoader(ds, batch_size=opt.batch_size, shuffle=opt.shuffle_dataset,
                     num_workers=opt.num_workers, drop_last=True)
print(f'loaded {len(ds)} images')

In [None]:
#  for img in dl:
#         print(img.shape)

In [None]:
def online_mean_and_sd(loader):
    """Compute per-channel mean and standard deviation in a single streaming pass.

    Uses running first/second moments and Var[X] = E[X^2] - (E[X])^2
    (population standard deviation).

    Args:
        loader: iterable of 4-D image batches shaped (B, C, H, W), or of
            (images, ...) tuples/lists whose first element is such a batch.

    Returns:
        (mean, std): two 1-D tensors of length C.

    Raises:
        ValueError: if the loader yields no pixels.
    """
    cnt = 0
    fst_moment = None
    snd_moment = None

    for batch in loader:
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        b, c, h, w = images.shape
        nb_pixels = b * h * w
        sum_ = torch.sum(images, dim=[0, 2, 3])
        sum_of_square = torch.sum(images ** 2, dim=[0, 2, 3])
        if fst_moment is None:
            # Zeros, not torch.empty: uninitialised memory may hold NaN/inf,
            # and 0 * NaN would poison the running average.
            fst_moment = torch.zeros(c, dtype=sum_.dtype, device=sum_.device)
            snd_moment = torch.zeros(c, dtype=sum_.dtype, device=sum_.device)
        fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels)
        snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels)
        cnt += nb_pixels

    if cnt == 0:
        raise ValueError('loader yielded no pixels; cannot compute mean/std')

    # Clamp: round-off can make E[X^2] - E[X]^2 slightly negative -> NaN in sqrt.
    variance = torch.clamp(snd_moment - fst_moment ** 2, min=0)
    return fst_moment, torch.sqrt(variance)
# mean, std = online_mean_and_sd(dl)
# print(mean,std)

In [None]:
class UNetConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by ReLU and, optionally, BatchNorm2d.

    Args:
        in_size: channels of the input feature map.
        out_size: channels produced by both convolutions.
        padding: padding of each 3x3 convolution (cast to int).
        batch_norm: append a BatchNorm2d after each ReLU when truthy.
    """

    def __init__(self, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        pad = int(padding)
        layers = []
        # First conv maps in_size -> out_size, second keeps out_size.
        for channels_in in (in_size, out_size):
            layers += [nn.Conv2d(channels_in, out_size, kernel_size=3, padding=pad), nn.ReLU()]
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_size))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

    
class UNetUpBlock(nn.Module):
    """Decoder stage: x2 spatial upsampling followed by a UNetConvBlock.

    Args:
        in_size: channels of the incoming (coarser) feature map.
        out_size: channels produced by the upsampling step and the conv block.
        up_mode: 'upconv' (ConvTranspose2d, kernel 2, stride 2) or
            'upsample' (bilinear x2 followed by a 1x1 conv).
        padding: padding of the 3x3 convolutions in the conv block.
        batch_norm: whether the conv block uses BatchNorm2d.

    Raises:
        ValueError: for an unknown ``up_mode``.
    """

    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_size, out_size, kernel_size=1),
            )
        else:
            raise ValueError(f"up_mode must be 'upconv' or 'upsample', got {up_mode!r}")

        # Consumes only the upsampled tensor (out_size channels), not a
        # concatenation with the encoder features.
        self.conv_block = UNetConvBlock(out_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        """Return the spatial centre of ``layer`` (N, C, H, W) cropped to ``target_size`` (h, w)."""
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[:, :, diff_y:(diff_y + target_size[0]), diff_x:(diff_x + target_size[1])]

    def forward(self, x, bridge):
        """Upsample ``x`` and refine it with the conv block.

        ``bridge`` (the matching encoder feature map) only supplies the target
        spatial size in 'upconv' mode; its values do not affect the output.

        NOTE(review): the previous code built ``cat([up, crop(bridge)])`` and
        then discarded it, so the decoder has never had skip connections.
        Keeping it that way preserves the trained architecture/state_dict; confirm it is
        intended (e.g. to stop the autoencoder copying its input through a shortcut).
        """
        if isinstance(self.up, nn.ConvTranspose2d):
            try:
                up = self.up(x, output_size=bridge.shape[-2:])
            except ValueError:
                # Requested size unreachable for a k=2/s=2 transpose conv
                # (e.g. padding=0 makes the bridge larger): keep the natural x2 size.
                up = self.up(x)
        else:
            # nn.Sequential does not accept an output_size argument.
            up = self.up(x)
        return self.conv_block(up)

In [None]:
class UNet(nn.Module):
    """U-Net-shaped encoder/decoder built from UNetConvBlock / UNetUpBlock stages.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (final 1x1 conv).
        padding: padding of the 3x3 convolutions.
        depth: number of encoder levels; the decoder has ``depth - 1`` stages.
        wf: width factor; encoder level ``i`` has ``2 ** (wf + i)`` channels.
        batch_norm: use BatchNorm2d inside the conv blocks.
        up_mode: 'upconv' or 'upsample' (see UNetUpBlock).
    """

    def __init__(self, in_channels=3, out_channels=3, padding=1, depth=4, wf=5,
                 batch_norm=False, up_mode='upsample'):
        super(UNet, self).__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        widths = [2 ** (wf + level) for level in range(depth)]

        self.down_path = nn.ModuleList()
        channels = in_channels
        for width in widths:
            self.down_path.append(UNetConvBlock(channels, width, padding, batch_norm))
            channels = width

        self.up_path = nn.ModuleList()
        # Decoder walks back up through every level except the deepest one.
        for width in reversed(widths[:-1]):
            self.up_path.append(UNetUpBlock(channels, width, up_mode, padding, batch_norm))
            channels = width

        self.last = nn.Conv2d(channels, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        skips = []
        deepest = len(self.down_path) - 1
        for level, down in enumerate(self.down_path):
            x = down(x)
            # Every level but the deepest stores its features and halves the resolution.
            if level < deepest:
                skips.append(x)
                x = F.max_pool2d(x, 2)

        # Pair each decoder stage with the encoder output of matching resolution.
        for up, skip in zip(self.up_path, reversed(skips)):
            x = up(x, skip)

        return self.last(x)



In [None]:
# Build the model on the selected device, then the LARS optimizer peaking at opt.max_lr.
network = UNet().to(device)
params = list(network.parameters())
# Alternative: torch.optim.Adam(params, lr=opt.learning_rate, weight_decay=opt.weight_decay)
optimizer = LARS(
    params,
    lr=opt.max_lr,
    momentum=opt.Base_momentum,
    weight_decay=opt.weight_decay,
)
# NOTE(review): the visible Trainer stores this scheduler but never steps it;
# its learning rate is set manually each iteration instead.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=float(opt.epochs - opt.warmup_epochs - 1),
    eta_min=opt.learning_rate_min,
)

In [None]:
# Display the module tree (the model's repr) as this cell's output.
network

In [None]:
class Trainer:
    """Reconstruction (MSE) training loop with linear warmup + cosine learning-rate decay.

    Args:
        network: model mapping an image batch to a reconstruction of the same shape.
        optimizer: optimizer over ``network``'s parameters; its ``lr`` is overwritten every step.
        scheduler: stored as ``self.scheduler``; not stepped by this class.
        device: device the batches are moved to (also used as checkpoint map_location).
        opt: config object (epochs, batch_size, max_lr, warmup_epochs, paths, ...).
        train_loader: loader used to size the LR schedule; defaults to the
            notebook-global ``dl`` (previous behaviour).
    """

    def __init__(self, network, optimizer, scheduler, device, opt, train_loader=None):
        self.model = network
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.opt = opt
        self.max_epochs = opt.epochs
        self.max_lr = opt.max_lr
        self.batch_size = opt.batch_size
        self.num_workers = opt.num_workers
        self.writer = SummaryWriter(opt.runs)

        loader = dl if train_loader is None else train_loader
        self.num_examples = len(loader.dataset)
        self.warmup_steps = opt.warmup_epochs * self.num_examples // opt.batch_size
        self.total_steps = opt.epochs * self.num_examples // opt.batch_size

    def _cosine_decay(self, step):
        """Half-cosine from ``max_lr`` at the end of warmup down to 0 at ``total_steps``."""
        # max(1, ...) guards against total_steps == warmup_steps (division by zero).
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        return 0.5 * self.max_lr * (1 + np.cos((step - self.warmup_steps) * np.pi / decay_steps))

    def update_learning_rate(self, step, decay='poly'):
        """Set every param group's lr for global ``step``: linear warmup, then cosine decay.

        ``decay`` is accepted for backward compatibility and is not used.
        NOTE(review): the decay ends at 0, not at ``opt.learning_rate_min``.
        """
        # Strict '<' gives the same value at step == warmup_steps (both branches
        # yield max_lr there) and avoids dividing by zero when warmup_steps == 0.
        if step < self.warmup_steps:
            lr = self.max_lr * step / self.warmup_steps
        else:
            lr = self._cosine_decay(step)

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def weight_init(self, m):
        """Xavier-normal weights and zero bias for every ``nn.Conv2d`` (use with ``Module.apply``)."""
        if isinstance(m, nn.Conv2d):
            init.xavier_normal_(m.weight)
            if m.bias is not None:  # convolutions built with bias=False have no bias
                init.constant_(m.bias, 0)

    def _resume_or_init(self):
        """Restore model/optimizer from ``<model_path>/model.pth`` or initialise fresh weights.

        Returns:
            (niter, start_epoch): global step to continue from and the first epoch to run.
        """
        try:
            print("Loading Pretrained models")
            state = torch.load(os.path.join(self.opt.model_path, 'model.pth'), map_location=self.device)
        except FileNotFoundError:
            print("Pre-trained weights not found. Training from scratch")
            self.model.apply(self.weight_init)
            return 0, 0

        self.model.load_state_dict(state['network_state_dict'])
        self.optimizer.load_state_dict(state['optimizer_state_dict'])
        print("Loaded pre-trained models with success.")
        niter = state['iter']
        saved_epoch = state['epoch']
        # float(): older checkpoints stored the loss as a tensor.
        print('NITER: %d | EPOCH: %d | Loss: %.3f ' % (niter, saved_epoch, float(state['loss'])))
        # The checkpoint is written after `saved_epoch` completed; resume with the next one.
        return niter, saved_epoch + 1

    def train(self, train_loader):
        """Train (or resume training) on ``train_loader``, checkpointing every epoch.

        ``train_loader`` is expected to yield batches of image tensors; the loss
        is the MSE between the model output and its input.

        Raises:
            ValueError: if ``train_loader`` has no batches.
        """
        if len(train_loader) == 0:
            raise ValueError('train_loader yields no batches')

        niter, start_epoch = self._resume_or_init()
        cost = torch.nn.MSELoss()
        checkpoint_path = os.path.join(self.opt.model_path, 'model.pth')

        for epoch in range(start_epoch, self.max_epochs):
            print()
            print('==================================================================')
            print('-------------Epoch: {}/{}------------'.format(epoch, self.max_epochs))
            epoch_loss = 0.0
            for img in train_loader:
                self.update_learning_rate(niter)
                img = img.to(self.device)

                self.optimizer.zero_grad()
                out = self.model(img)
                loss = cost(out, img)
                loss.backward()
                self.optimizer.step()

                # Plain float: keeps the autograd graph out of logs and checkpoints.
                loss_value = loss.item()
                epoch_loss += loss_value

                if (niter + 1) % 500 == 0:
                    print('ITER: %d | Loss: %.3f ' % (niter, loss_value))

                self.writer.add_scalar('loss', loss_value, global_step=niter)
                niter += 1

            print("End of epoch {}".format(epoch))
            mean_loss = epoch_loss / len(train_loader)
            print('EPOCH: %d | Loss: %.3f ' % (epoch, mean_loss))
            with open(f'{self.opt.save}/logs.txt', 'a') as file:
                file.write(str(epoch) + ' ' + str(mean_loss) + '\n')
            self.writer.flush()

            # rolling checkpoint every epoch, plus a kept snapshot and a decoded sample every 50
            self.save_model(checkpoint_path, loss_value, niter, epoch)
            if epoch % 50 == 0:
                self.save_model(os.path.join(self.opt.model_path, f'model_{epoch}.pth'), loss_value, niter, epoch)
                filename = 'decoded_%03d.png' % (epoch)
                path = os.path.join(self.opt.decode_path, filename)
                tv.utils.save_image(out.detach().cpu(), path, normalize=True)
                print(f'{filename} saved.')

    def save_model(self, PATH, loss, niter, epoch):
        """Write loss, global step, epoch and model/optimizer state dicts to ``PATH``."""
        torch.save({
            'loss': loss,
            'iter': niter,
            'epoch': epoch,
            'network_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, PATH)

In [None]:
# Wire the model, optimizer, scheduler and config into the training driver.
trainer = Trainer(
    network=network,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    opt=opt,
)

In [None]:
# Start training, or resume from opt.model_path/model.pth when that checkpoint exists.
trainer.train(dl)