# <font color="blue">**CS685  : Project**</font> 

### Self-Supervised U-Net model implementation and weight saving

In [None]:
# Mount Google Drive at /content/drive so that files stored there can be
# read and written through the local filesystem.
# Earlier variant without a forced remount, kept for reference:
# from google.colab import drive
# drive.mount('/content/drive')
from google.colab import drive
drive.mount('/content/drive', force_remount=True) 

In [None]:
# Ask TensorFlow for the name of a visible GPU device; as the last expression
# of the cell, the returned value is what the notebook displays.
import tensorflow as tf
tf.test.gpu_device_name()

In [None]:
# List the devices TensorFlow can see locally (displayed as the cell output).
from tensorflow.python.client import device_lib
device_lib.list_local_devices()

# PART 1: Self-Supervised Learning

## Utils

In [None]:
from __future__ import print_function
import math
import os
import random
import copy
import scipy
import imageio
import string
import numpy as np
from datetime import datetime
import pytz
from skimage.transform import resize
try:  # SciPy >= 0.19
    from scipy.special import comb
except ImportError:
    from scipy.misc import comb

def bernstein_poly(i, n, t):
    """Evaluate one Bernstein basis term of degree ``n`` at ``t``.

    The value returned is ``comb(n, i) * t**(n - i) * (1 - t)**i``. ``t`` may
    be a scalar or a NumPy array; with an array the result is element-wise.
    """
    binomial = comb(n, i)
    t_power = t ** (n - i)
    complement_power = (1 - t) ** i
    return binomial * t_power * complement_power

def bezier_curve(points, nTimes=1000):
    """Sample the Bezier curve defined by a sequence of control points.

    Args:
        points: sequence of ``[x, y]`` pairs (lists or tuples), e.g.
            ``[[1, 1], [2, 3], [4, 5], ..., [Xn, Yn]]``.
        nTimes: number of evenly spaced parameter values in ``[0, 1]`` at
            which the curve is evaluated (default 1000).

    Returns:
        ``(xvals, yvals)``: two arrays of length ``nTimes`` holding the curve
        coordinates. Because ``bernstein_poly`` weights control point ``i``
        by ``t**(n - i) * (1 - t)**i``, the samples start at the last control
        point (``t = 0``) and end at the first one (``t = 1``).

    See http://processingjs.nihongoresources.com/bezierinfo/
    """
    count = len(points)
    ctrl_x = np.array([pt[0] for pt in points])
    ctrl_y = np.array([pt[1] for pt in points])

    t = np.linspace(0.0, 1.0, nTimes)

    # One row of basis values per control point: shape (count, nTimes).
    basis = np.array([bernstein_poly(k, count - 1, t) for k in range(count)])

    return np.dot(ctrl_x, basis), np.dot(ctrl_y, basis)

def data_augmentation(x, y, prob=0.5):
    """Apply the same random flips to ``x`` and ``y``.

    A flip along a randomly chosen axis (0, 1 or 2) is applied for as long as
    a fresh random draw stays below ``prob``, up to three times in total.

    Returns:
        The ``(x, y)`` pair after flipping; the inputs themselves are returned
        when no flip is applied.
    """
    flips_left = 3
    while True:
        # Draw before looking at the budget, exactly like the short-circuit
        # ``random.random() < prob and cnt > 0`` condition would.
        if not random.random() < prob:
            break
        if flips_left <= 0:
            break
        axis = random.choice([0, 1, 2])
        x = np.flip(x, axis=axis)
        y = np.flip(y, axis=axis)
        flips_left -= 1

    return x, y

def nonlinear_transformation(x, prob=0.5):
    """Remap the values of ``x`` through a random monotone-ish Bezier curve.

    With probability ``1 - prob`` the input is returned untouched. Otherwise
    a Bezier curve through (0, 0), two random control points and (1, 1) is
    sampled and used as a lookup table for ``np.interp``.

    Returns:
        ``x`` itself when the transformation is skipped, else a new array
        with the same shape as ``x``.
    """
    if random.random() >= prob:
        return x

    control = [[0, 0]]
    control.append([random.random(), random.random()])
    control.append([random.random(), random.random()])
    control.append([1, 1])
    xvals, yvals = bezier_curve(control, nTimes=100000)

    # The x samples are always sorted. Half of the time the y samples are
    # left in curve order (half chance to get a flipped mapping); otherwise
    # they are sorted as well.
    xvals = np.sort(xvals)
    if not random.random() < 0.5:
        yvals = np.sort(yvals)

    return np.interp(x, xvals, yvals)

def local_pixel_shuffling(x, prob=0.5):
    """Shuffle pixels inside many small random windows of channel 0.

    With probability ``1 - prob`` the input is returned untouched. Otherwise
    10000 random windows (each side between 1 and a tenth of the image side)
    are read from an unmodified copy of ``x``, shuffled, and written into the
    result at the same location. Only index 0 of the first axis is shuffled.

    Returns:
        ``x`` itself when skipped, else a new array; ``x`` is not modified.
    """
    if random.random() >= prob:
        return x

    shuffled = copy.deepcopy(x)
    pristine = copy.deepcopy(x)
    _, img_rows, img_cols = x.shape
    max_height = img_rows // 10
    max_width = img_cols // 10

    for _ in range(10000):
        height = random.randint(1, max_height)
        width = random.randint(1, max_width)
        top = random.randint(0, img_rows - height)
        left = random.randint(0, img_cols - width)
        row_span = slice(top, top + height)
        col_span = slice(left, left + width)

        # flatten() copies, so shuffling never touches the pristine image.
        patch = pristine[0, row_span, col_span].flatten()
        np.random.shuffle(patch)
        shuffled[0, row_span, col_span] = patch.reshape((height, width))

    return shuffled

def image_in_painting(x):
    """Overwrite up to five random windows of ``x`` with uniform noise.

    Each window has sides between a sixth and a third of the image side and
    stays at least 3 pixels away from the border. Before each window a random
    draw decides (with probability 0.95) whether to continue. The same noise
    patch is written across the whole first axis.

    Returns:
        The same array object ``x``, modified in place.
    """
    _, img_rows, img_cols = x.shape
    for _ in range(5):
        if not random.random() < 0.95:
            break
        height = random.randint(img_rows // 6, img_rows // 3)
        width = random.randint(img_cols // 6, img_cols // 3)
        top = random.randint(3, img_rows - height - 3)
        left = random.randint(3, img_cols - width - 3)
        noise = np.random.rand(height, width, ) * 1.0
        x[:, top:top + height, left:left + width] = noise
    return x

def image_out_painting(x):
    """Replace ``x`` by uniform noise, keeping only a few original windows.

    A noise array of the same shape is created, then one window of the
    original image is copied back unconditionally, followed by up to four
    more windows (each gated by a 0.95-probability draw). Window sides are
    the image side minus a random amount between 3/7 and 4/7 of it, and the
    windows stay at least 3 pixels away from the border.

    Returns:
        A new array; the input ``x`` is not modified.
    """
    _, img_rows, img_cols = x.shape
    original = copy.deepcopy(x)
    x = np.random.rand(x.shape[0], x.shape[1], x.shape[2], ) * 1.0

    def _restore_random_window():
        # Copy one randomly placed window of the original back over the noise.
        height = img_rows - random.randint(3*img_rows//7, 4*img_rows//7)
        width = img_cols - random.randint(3*img_cols//7, 4*img_cols//7)
        top = random.randint(3, img_rows - height - 3)
        left = random.randint(3, img_cols - width - 3)
        x[:, top:top + height, left:left + width] = \
            original[:, top:top + height, left:left + width]

    _restore_random_window()
    for _ in range(4):
        if not random.random() < 0.95:
            break
        _restore_random_window()
    return x
                


def generate_pair(img, batch_size, config, status="test"):
    """Endlessly yield ``(distorted, target)`` batches for self-supervision.

    Each round draws ``batch_size`` random images from ``img`` (fewer if
    ``img`` holds fewer), flips input and target together, then distorts only
    the input with local pixel shuffling, a non-linear intensity mapping and
    in-/out-painting, each gated by the rates found on ``config``.

    Args:
        img: array indexed as ``(sample, channel, row, col)``.
        batch_size: maximum number of samples per yielded batch.
        config: object providing ``flip_rate``, ``local_rate``,
            ``nonlinear_rate``, ``paint_rate``, ``inpaint_rate`` and
            ``save_samples`` (plus ``sample_path`` and ``exp_name`` when
            samples are saved).
        status: sample images are written only when this equals ``"train"``.

    Yields:
        ``(x, y)`` where ``y`` is the flipped original batch and ``x`` its
        distorted counterpart, both with the same shape.
    """
    while True:
        index = list(range(img.shape[0]))
        random.shuffle(index)
        y = img[index[:batch_size]]  # fancy indexing copies, img stays intact
        x = copy.deepcopy(y)

        # Use the number of rows actually drawn: it is smaller than
        # batch_size when img holds fewer samples, and indexing by the
        # requested size would then run past the end of the batch.
        drawn = y.shape[0]
        for n in range(drawn):
            # Flip
            x[n], y[n] = data_augmentation(x[n], y[n], config.flip_rate)

            # Local Shuffle Pixel
            x[n] = local_pixel_shuffling(x[n], prob=config.local_rate)

            # Apply non-Linear transformation with an assigned probability
            x[n] = nonlinear_transformation(x[n], config.nonlinear_rate)

            # Inpainting & Outpainting
            if random.random() < config.paint_rate:
                if random.random() < config.inpaint_rate:
                    # Inpainting
                    x[n] = image_in_painting(x[n])
                else:
                    # Outpainting
                    x[n] = image_out_painting(x[n])

        # Save sample images module
        if config.save_samples is not None and status == "train" and random.random() < 0.01:
            # Pick from this batch, not from config.batch_size, which may
            # differ from the batch_size argument.
            n_sample = random.choice(range(drawn))
            final_sample = np.concatenate((x[n_sample, 0, :, :], y[n_sample, 0, :, :]), axis=1)
            final_sample = final_sample * 255.0
            final_sample = final_sample.astype(np.uint8)
            stem = ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(10))
            file_name = stem + '.' + config.save_samples
            # NOTE(review): models_genesis_config in this notebook defines no
            # `sample_path`; confirm it is set before calling with
            # status="train", and that the target directory exists.
            imageio.imwrite(os.path.join(config.sample_path, config.exp_name, file_name), final_sample)

        yield (x, y)

        


# UNET

In [None]:
""" Full assembly of the parts to form the complete network """

""" Parts of the U-Net model """

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two consecutive (3x3 convolution => BatchNorm => ReLU) stages.

    ``mid_channels`` sets the width between the two stages and falls back to
    ``out_channels`` when it is not given.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        stages = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Attribute name and layer order define the state_dict keys.
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages; spatial size is preserved."""
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder step: 2x2 max-pooling followed by a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Downscale ``x`` and convolve it to ``out_channels`` channels."""
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder step: upscale, merge with the skip connection, double conv."""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2

        if not bilinear:
            # Learned upsampling halves the channels itself.
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Bilinear upsampling keeps the channel count, so the following
            # convolutions are used to reduce the number of channels.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upscale ``x1``, pad it to the size of ``x2`` and fuse both.

        Args:
            x1: feature map coming from the deeper layer.
            x2: skip-connection feature map from the encoder.
        """
        upsampled = self.up(x1)

        # Tensors are indexed as (N, C, H, W): dims 2 and 3 are spatial.
        gap_h = x2.size()[2] - upsampled.size()[2]
        gap_w = x2.size()[3] - upsampled.size()[3]
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Output head: 1x1 convolution followed by a sigmoid."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Project ``x`` to ``out_channels`` and squash the result to (0, 1)."""
        projected = self.conv(x)
        return self.sigmoid(projected)



class UNet(nn.Module):
    """U-Net with four down steps, four up steps and a sigmoid output head.

    Args:
        n_channels: number of channels of the input image.
        n_classes: number of output channels.
        bilinear: use bilinear upsampling in the decoder (True) or
            transposed convolutions (False).
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # With bilinear upsampling the decoder widths are halved.
        factor = 2 if bilinear else 1

        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return the output of ``self.outc`` for the input batch ``x``.

        Note that ``OutConv`` ends with a sigmoid, so the returned values are
        already squashed rather than raw logits.
        """
        skip1 = self.inc(x)          # conv1
        skip2 = self.down1(skip1)    # down1
        skip3 = self.down2(skip2)    # down2
        skip4 = self.down3(skip3)    # down3
        bottleneck = self.down4(skip4)  # down4

        decoded = self.up1(bottleneck, skip4)
        decoded = self.up2(decoded, skip3)
        decoded = self.up3(decoded, skip2)
        decoded = self.up4(decoded, skip1)

        return self.outc(decoded)

    
class UNet_hidden(nn.Module):
    """Same network as ``UNet`` that also returns the bottleneck features.

    Args:
        n_channels: number of channels of the input image.
        n_classes: number of output channels.
        bilinear: use bilinear upsampling in the decoder (True) or
            transposed convolutions (False).
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # With bilinear upsampling the decoder widths are halved.
        factor = 2 if bilinear else 1

        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return ``(output, bottleneck)`` for the input batch ``x``.

        ``output`` is what ``self.outc`` produces (sigmoid already applied by
        ``OutConv``); ``bottleneck`` is the output of the deepest encoder
        step, ``self.down4``.
        """
        skip1 = self.inc(x)          # conv1
        skip2 = self.down1(skip1)    # down1
        skip3 = self.down2(skip2)    # down2
        skip4 = self.down3(skip3)    # down3
        bottleneck = self.down4(skip4)  # down4

        decoded = self.up1(bottleneck, skip4)
        decoded = self.up2(decoded, skip3)
        decoded = self.up3(decoded, skip2)
        decoded = self.up4(decoded, skip1)

        return self.outc(decoded), bottleneck

## Swin_UNET

In [None]:
# import torch
# import torch.nn as nn
# # from torchvision.models.utils import load_state_dict_from_url

# # Swin Transformer Encoder Block
# class SwinTransformerEncoderBlock(nn.Module):
#     def __init__(self, embed_dim, num_heads, window_size, shift_size, mlp_ratio=4.0):
#         super().__init__()
#         self.norm1 = nn.LayerNorm(embed_dim)
#         self.attn = nn.MultiheadAttention(embed_dim, num_heads)
#         self.norm2 = nn.LayerNorm(embed_dim)
#         self.mlp = nn.Sequential(
#             nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
#             nn.GELU(),
#             nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
#         )

#         # Window partitioning and shifting
#         self.window_size = window_size
#         self.shift_size = shift_size

#     def forward(self, x):
#         # Apply layer normalization
#         x = self.norm1(x)

#         # Apply self-attention
#         x = x.permute(2, 0, 1)
#         x, _ = self.attn(x, x, x)
#         x = x.permute(1, 2, 0)

#         # Apply residual connection and layer normalization
#         x = x + x.permute(0, 2, 1)
#         x = self.norm2(x)

#         # Apply MLP
#         y = self.mlp(x)

#         # Partition into non-overlapping windows and shift
#         B, N, C = y.shape
#         h = self.window_size
#         w = N // h
#         y = y.view(B, h, w, C)
#         y = y.permute(0, 3, 1, 2)
#         y = torch.nn.functional.pad(y, (0, 0, 0, 0, self.shift_size // 2, self.shift_size // 2), mode="constant")
#         y = y.reshape(B, C, h * (w + 2 * (self.shift_size // 2)))
#         y = y.permute(0, 2, 1)

#         # Apply residual connection
#         y = y + x

#         return y

# # Swin UNet Model
# class SwinUNet(nn.Module):
#     def __init__(self, input_channels, num_classes):
#         super().__init__()

#         # Swin Transformer Encoder Blocks
#         self.enc1 = SwinTransformerEncoderBlock(embed_dim=64, num_heads=2, window_size=7, shift_size=0)
#         self.enc2 = SwinTransformerEncoderBlock(embed_dim=128, num_heads=4, window_size=7, shift_size=0)
#         self.enc3 = SwinTransformerEncoderBlock(embed_dim=256, num_heads=8, window_size=7, shift_size=0)
#         self.enc4 = SwinTransformerEncoderBlock(embed_dim=512, num_heads=16, window_size=7, shift_size=0)

#         # Swin Transformer Decoder Blocks
#         self.dec1 = SwinTransformerEncoderBlock(embed_dim=256, num_heads=8, window_size=7, shift_size=0)
#         self.dec2 = SwinTransformerEncoderBlock(embed_dim=128, num_heads=4, window_size=7, shift_size=0)
#         self.dec3 = SwinTransformerEncoderBlock(embed_dim=64, num_heads=2, window_size=7, shift_size=0)

#         # Final convolutional layer for segmentation output
#         self.final_conv = nn.Conv2d(96, num_classes, kernel_size=1, stride=1, padding=0)

#     def forward(self, x):
#         # Encoder
#         x1 = self.enc1(x)
#         x2 = self.enc2(x1)
#         x3 = self.enc3(x2)
#         x4 = self.enc4(x3)

#         # Decoder
#         y1 = self.dec1(x4) + nn.functional.interpolate(x4, scale_factor=2, mode="nearest")
#         y2 = self.dec2(y1) + nn.functional.interpolate(y1, scale_factor=2, mode="nearest")
#         y3 = self.dec3(y2) + nn.functional.interpolate(y2, scale_factor=2, mode="nearest")

#         # Output
#         y = self.final_conv(y3)
#         return y


In [None]:
import torch
import torch.nn as nn
# from torchvision.models.utils import load_state_dict_from_url

# Swin Transformer Encoder Block
class SwinTransformerEncoderBlock(nn.Module):
    """Experimental transformer block: LayerNorm, self-attention, MLP, reshaping.

    NOTE(review): this block looks unfinished. Several tensor shapes in
    ``forward`` do not appear to line up (see the notes below), and the only
    model that uses it is commented out in the training cell. Treat it as a
    draft until it has been run end to end.
    """

    def __init__(self, embed_dim, num_heads, window_size, shift_size, mlp_ratio=4.0):
        super().__init__()
        # NOTE(review): the second positional parameter of nn.LayerNorm is
        # `eps`, so the tuple (256, 256) is passed as eps here rather than as
        # part of the normalized shape -- confirm what was intended.
        self.norm1 = nn.LayerNorm(embed_dim,(256, 256))
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim,(256, 256))
        # Two-layer MLP that widens to embed_dim * mlp_ratio and back.
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
        )

        # Window partitioning and shifting
        self.window_size = window_size
        self.shift_size = shift_size

    def forward(self, x):
        """Apply the block to ``x``.

        NOTE(review): the permutes below treat ``x`` as a 4-D tensor with
        channels on axis 1 -- TODO confirm the expected input layout.
        """
        # Apply layer normalization
        x = self.norm1(x)

        # Apply self-attention
        x = x.permute(0, 2, 3, 1)
        B, H, W, C = x.shape
        # NOTE(review): the last dimension becomes 1 here, while self.attn was
        # built with embed_dim as its embedding size -- looks inconsistent.
        x = x.reshape(B * H * W, C, 1)
        x, _ = self.attn(x, x, x)
        x = x.reshape(B, H, W, C)
        x = x.permute(0, 3, 1, 2)

        # Apply residual connection and layer normalization
        # NOTE(review): this adds x to a permutation of itself, not to the
        # tensor that entered the attention step -- confirm this is intended.
        x = x + x.permute(0, 2, 3, 1)
        x = self.norm2(x)

        # Apply MLP
        y = self.mlp(x)

        # Partition into non-overlapping windows and shift
        # NOTE(review): y appears to be 4-D at this point, so unpacking three
        # values from its shape would fail -- verify.
        B, N, C = y.shape
        h = self.window_size
        w = N // h
        y = y.view(B, h, w, C)
        y = y.permute(0, 3, 1, 2)
        y = torch.nn.functional.pad(y, (0, 0, 0, 0, self.shift_size // 2, self.shift_size // 2), mode="constant")
        y = y.reshape(B, C, h * (w + 2 * (self.shift_size // 2)))
        y = y.permute(0, 2, 1)

        # Apply residual connection
        y = y + x

        return y

# Swin UNet Model
class SwinUNet(nn.Module):
    """Experimental encoder/decoder built from ``SwinTransformerEncoderBlock``.

    An input convolution maps ``input_channels`` to 64 channels, four encoder
    blocks and three decoder blocks follow, and a 1x1 convolution produces
    ``num_classes`` output channels.

    NOTE(review): the blocks are created with different ``embed_dim`` values
    (64, 128, 256, 512, ...) but no layer visible here changes the channel
    count between them, and ``final_conv`` expects 96 input channels while
    the last decoder block uses 64 -- confirm the intended widths before use.
    """

    def __init__(self, input_channels, num_classes):
        super().__init__()

        # Swin Transformer Encoder Blocks
        self.enc1 = SwinTransformerEncoderBlock(embed_dim=64, num_heads=2, window_size=7, shift_size=0,mlp_ratio=4.0)
        self.enc2 = SwinTransformerEncoderBlock(embed_dim=128, num_heads=4, window_size=7, shift_size=0,mlp_ratio=4.0)
        self.enc3 = SwinTransformerEncoderBlock(embed_dim=256, num_heads=8, window_size=7, shift_size=0,mlp_ratio=4.0)
        self.enc4 = SwinTransformerEncoderBlock(embed_dim=512, num_heads=16, window_size=7, shift_size=0,mlp_ratio=4.0)

        # Swin Transformer Decoder Blocks
        self.dec1 = SwinTransformerEncoderBlock(embed_dim=256, num_heads=8, window_size=7, shift_size=0,mlp_ratio=4.0)
        self.dec2 = SwinTransformerEncoderBlock(embed_dim=128, num_heads=4, window_size=7, shift_size=0,mlp_ratio=4.0)
        self.dec3 = SwinTransformerEncoderBlock(embed_dim=64, num_heads=2, window_size=7, shift_size=0,mlp_ratio=4.0)

        # Final convolutional layer for segmentation output
        self.final_conv = nn.Conv2d(96, num_classes, kernel_size=1, stride=1, padding=0)

        # Modify the input channel of the first convolutional layer
        self.conv_input = nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Run the encoder, the decoder and the final 1x1 convolution on ``x``."""
        # Encoder
        x = self.conv_input(x)
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x3 = self.enc3(x2)
        x4 = self.enc4(x3)

        # Decoder
        # Each step adds the block output to a 2x nearest-neighbour upsampling
        # of the same block input.
        # NOTE(review): the two summands would need matching shapes -- verify.
        y1 = self.dec1(x4) + nn.functional.interpolate(x4, scale_factor=2, mode="nearest")
        y2 = self.dec2(y1) + nn.functional.interpolate(y1, scale_factor=2, mode="nearest")
        y3 = self.dec3(y2) + nn.functional.interpolate(y2, scale_factor=2, mode="nearest")

        # Output
        y = self.final_conv(y3)
        return y


## Config_clusters

In [None]:
import os
import shutil
from datetime import datetime
import pytz

class models_genesis_config:
    """Settings for the self-supervised pre-training run.

    All values are class attributes. Evaluating the class body has side
    effects: it prints the output paths and creates the timestamped model
    directory together with its ``Logs`` and ``snapshot`` sub-directories.
    """

    model = "Unet2D"
    suffix = "genesis_chest_ct"
    exp_name = "-".join([model, suffix])

    # data
    data = "/mnt/dataset/shared/zongwei/LUNA16/Self_Learning_Cubes" # not use
    scale = 32
    input_rows = 256
    input_cols = 256
    input_deps = 1
    nb_class = 1

    # image deformation
    nonlinear_rate = 0.9
    paint_rate = 0.9
    outpaint_rate = 0.8
    inpaint_rate = 1.0 - outpaint_rate
    local_rate = 0.5
    flip_rate = 0.4

    # logs
    # model_dir = "../SSLModel/Reuslts/pretrained_weights"
    model_dir = "/content/drive/MyDrive/Spring_research_2023/SSLModel/Reuslts/pretrained_weights"
    # One output directory per run, named after the start time (Singapore time).
    timenow = datetime.now(pytz.timezone('Asia/Singapore')).strftime('%Y-%m-%d_%H-%M-%S')

    model_path = os.path.join(model_dir, timenow)
    print('Model path: ', model_path)
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    logs_path = os.path.join(model_path, "Logs")
    print('log path: ', logs_path)
    if not os.path.exists(logs_path):
        os.makedirs(logs_path)

    shotdir = os.path.join(model_path, 'snapshot')
    print('snapshot path: ', shotdir)
    if not os.path.exists(shotdir):
        os.makedirs(shotdir)

    # model pre-training
    verbose = 1
    weights = os.path.join(model_path, 'ISIC_Unsup.pt')
    batch_size = 1
    optimizer = "sgd"
    workers = 10
    max_queue_size = 4 * workers
    save_samples = "png"
    nb_epoch = 10000
    patience = 100
    lr = 0.01

    def display(self):
        """Print every non-dunder, non-callable attribute with its value."""
        print("\nConfigurations:")
        for name in dir(self):
            if name.startswith("__"):
                continue
            value = getattr(self, name)
            if callable(value):
                continue
            print("{:30} {}".format(name, value))
        print("\n")


## Data Loader

In [None]:
from torch.utils.data import Dataset
import os
import numpy as np
import re
import imgaug.augmenters as iaa
#from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from torchvision import transforms
import cv2
import random
from tqdm import tqdm

import matplotlib.pyplot as plt

random.seed(1)


def Dataset_Loader(path, img_size):
    """Load an image stack from a ``.npy`` file, resize it and scale by 1/255.

    Args:
        path: location of the ``.npy`` file holding the images.
        img_size: ``[rows, cols]`` target size of every image.

    Returns:
        A float array of shape ``(N, read_imgs.shape[-1], rows, cols)``
        whose channel 0 holds the resized images divided by 255; any further
        channels are zero.
    """
    print('\nLoading dataset...\n')
    read_imgs = np.load(path)
    rows, cols = img_size[0], img_size[1]

    # np.zeros rather than the raw np.ndarray constructor: only channel 0 is
    # written below, so any other channel would otherwise expose
    # uninitialised memory to the caller.
    images = np.zeros((read_imgs.shape[0], read_imgs.shape[-1], rows, cols), dtype=float)
    for i in range(read_imgs.shape[0]):
        # NOTE(review): the buffer takes its channel count from the LAST axis
        # of read_imgs while the image is taken from index 0 of the SECOND
        # axis -- confirm the stored layout makes both consistent.
        img = cv2.resize(read_imgs[i, 0], (cols, rows), interpolation=cv2.INTER_CUBIC)
        images[i, 0, :, :] = img / 255.
    return images

# Load the training images once when this cell/module is run directly
# (in a notebook cell ``__name__`` is ``"__main__"``, so this executes).
if __name__ == "__main__":
    data_path_train = '/content/drive/MyDrive/Spring_research_2023/data/GrayData'
    trainpath = data_path_train + '/imgs_train.npy'
    
    # Resize every image to 256 x 256 while loading.
    dataset = Dataset_Loader(trainpath,[256,256])

## Pre-train

In [None]:
#!/usr/bin/env python
# coding: utf-8
# ref https://github.com/MrGiovanni/ModelsGenesis


import warnings
warnings.filterwarnings('ignore')
import numpy as np
from torch import nn
import torch
from torchsummary import summary
import sys
# from utils import *
# from unet_model2 import UNet
# from config_cluster import models_genesis_config
# from data_load import Dataset_Loader
import logging
import os

# ---------------------------------------------------------------------------
# Pre-training setup: configuration, data split, generators, model, loss,
# optimizer, LR scheduler and the bookkeeping used by the training loop.
# ---------------------------------------------------------------------------
print("torch = {}".format(torch.__version__))
# NOTE(review): this restricts CUDA to the device with index 3. On a runtime
# with fewer GPUs no CUDA device may remain visible -- confirm this value is
# intended for the machine the notebook runs on.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

conf = models_genesis_config()
conf.display()
img_size = [256,256]

# data load

'''
in unsupervised learning,
train set = 1600
validation set = 400
test 1: fix the train set for the first 1600
pixel value scale :[0,1]
resize: 16*N  = 256
'''
train_path = '/content/drive/MyDrive/Spring_research_2023/data/GrayData/imgs_train.npy'
train_set = Dataset_Loader(train_path,img_size)

# Fixed split: the first train_num images train, the next valid_num validate.
train_num =  1600 # Lakmali: 1600
valid_num =  400 # Lakmali: 400
total_num = len(train_set)
x_train = train_set[0:train_num]
x_valid = train_set[train_num:train_num+valid_num]

logging.basicConfig(filename=conf.shotdir+"/"+"snapshot.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
logging.info(str(conf))

print("x_train: {} | {:.2f} ~ {:.2f}".format(x_train.shape, np.min(x_train), np.max(x_train)))
print("x_valid: {} | {:.2f} ~ {:.2f}".format(x_valid.shape, np.min(x_valid), np.max(x_valid)))

# Endless generators of (distorted input, target) batches.
training_generator = generate_pair(x_train,conf.batch_size, conf)
validation_generator = generate_pair(x_valid,conf.batch_size, conf)


device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Lakmali: model = UNet(n_channels=1, n_classes=conf.nb_class).cuda()
# Move the model with .to(device) instead of .cuda(): the unconditional
# .cuda() call raised on CPU-only runtimes even though `device` already
# falls back to the CPU.
model = UNet(n_channels=1, n_classes=conf.nb_class).to(device)
#model = SwinUNet(input_channels=1, num_classes=conf.nb_class)

print("Total CUDA devices: ", torch.cuda.device_count())

summary(model, (1,conf.input_rows,conf.input_cols), batch_size=-1)
criterion = nn.MSELoss()

if conf.optimizer == "sgd":
    optimizer = torch.optim.SGD(model.parameters(), conf.lr, momentum=0.9, weight_decay=0.0, nesterov=False)
elif conf.optimizer == "adam":
    optimizer = torch.optim.Adam(model.parameters(), conf.lr)
else:
    # A bare `raise` has no active exception here and would only produce an
    # unrelated RuntimeError; name the actual problem instead.
    raise ValueError("Unsupported optimizer: {!r}".format(conf.optimizer))

# Halve the learning rate every 0.8 * patience epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(conf.patience * 0.8), gamma=0.5)

# to track the training loss as the model trains
train_losses = []
# to track the validation loss as the model trains
valid_losses = []
# to track the average training loss per epoch as the model trains
avg_train_losses = []
# to track the average validation loss per epoch as the model trains
avg_valid_losses = []
#best_loss = 100000
# A checkpoint is only written once the validation loss drops below this.
best_loss = 0.03 # Lakmali: 0.02
intial_epoch =0
num_epoch_no_improvement = 0
sys.stdout.flush()

print(conf.weights) # Lakmali:

'''
# Lakmali:
file_path = conf.weights

# Save the model weights as a .pt file
torch.save(model.state_dict(), file_path)


if conf.weights != None:    
	checkpoint=torch.load(conf.weights)
	model.load_state_dict(checkpoint['state_dict'])
	optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
	intial_epoch=checkpoint['epoch']
	print("Loading weights from ",conf.weights)
sys.stdout.flush()
'''

# ---------------------------------------------------------------------------
# Training loop: 100 training iterations per epoch, a full validation pass,
# checkpointing on improvement and early stopping after `conf.patience`
# epochs without improvement.
# ---------------------------------------------------------------------------
for epoch in range(intial_epoch,conf.nb_epoch):
    # NOTE(review): passing the epoch to scheduler.step() is the deprecated
    # calling form; it is kept so the learning rate still follows the epoch
    # index if training is resumed from a later `intial_epoch`.
    scheduler.step(epoch)
    model.train()
    for iteration in range(100): # Lakmali:  for iteration in range(int(x_train.shape[0]//conf.batch_size)):
        image, gt = next(training_generator)
        gt = np.repeat(gt,conf.nb_class,axis=1)
        image = torch.from_numpy(image).float().to(device)
        gt = torch.from_numpy(gt).float().to(device)

        # The model's output head (OutConv) already ends with a sigmoid.
        # Applying torch.sigmoid a second time squeezed predictions into
        # roughly (0.5, 0.73), which can never match targets scaled to
        # [0, 1], so the model output is used directly.
        pred = model(image)

        loss = criterion(pred,gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Keep full precision: rounding each loss to two decimals made the
        # running average meaningless at the ~0.03 scale used by best_loss.
        train_losses.append(loss.item())
        if (iteration + 1) % 5 ==0:
            print('Epoch [{}/{}], iteration {}, Loss: {:.6f}'.format(epoch + 1, conf.nb_epoch, iteration + 1, np.average(train_losses)))
        sys.stdout.flush()

    with torch.no_grad():
        model.eval()
        print("validating....")
        for i in range(int(x_valid.shape[0]//conf.batch_size)):
            x,y = next(validation_generator)
            y = np.repeat(y,conf.nb_class,axis=1)
            image = torch.from_numpy(x).float().to(device)
            gt = torch.from_numpy(y).float().to(device)
            # Same as in training: the output is already sigmoid-activated.
            pred = model(image)
            loss = criterion(pred,gt)
            valid_losses.append(loss.item())

    #logging
    train_loss=np.average(train_losses)
    valid_loss=np.average(valid_losses)
    avg_train_losses.append(train_loss)
    avg_valid_losses.append(valid_loss)
    print("Epoch {}, validation loss is {:.4f}, training loss is {:.4f}".format(epoch+1,valid_loss,train_loss))
    train_losses=[]
    valid_losses=[]
    if valid_loss < best_loss:
        print("Validation loss decreases from {:.4f} to {:.4f}".format(best_loss, valid_loss))
        best_loss = valid_loss
        num_epoch_no_improvement = 0
        #save model
        torch.save({
            'epoch': epoch+1,
            'state_dict' : model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        },os.path.join(conf.model_path, "ISIC_Unsup.pt"))
        print("Saving model ",os.path.join(conf.model_path, "ISIC_Unsup.pt"))
    else:
        # Count this epoch before reporting so the printed number includes it.
        num_epoch_no_improvement += 1
        print("Validation loss does not decrease from {:.4f}, num_epoch_no_improvement {}".format(best_loss,num_epoch_no_improvement))
    if num_epoch_no_improvement == conf.patience:
        print("Early Stopping")
        break
    sys.stdout.flush()



