In [1]:
from platform import python_version
# Record which interpreter version this notebook was executed with.
interpreter_version = python_version()
print(interpreter_version)

3.8.6


In [2]:
import wandb
# Authenticate with Weights & Biases so that wandb.init()/wandb.log() used further
# down can report runs; the captured output below shows an existing login being reused.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mkrislara[0m (use `wandb login --relogin` to force relogin)


True

In [3]:
import random
import math
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm.notebook import tqdm
import time
from PIL import Image
import torch.utils.data as data
import os
import numpy as np
from pytorchtools import EarlyStopping
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from skimage.metrics import structural_similarity as ssim

# Reproducibility: seed every RNG with one fixed integer.
# The previous seeds were derived from hash("..."), but str hashing is salted per
# interpreter process (PYTHONHASHSEED), so they changed on every kernel start; and
# `hash(...) % 2**32 - 1` could evaluate to -1, which np.random.seed rejects.
SEED = 42
torch.backends.cudnn.deterministic = True
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
print(torch.__version__)

cuda:0
1.4.0


In [4]:
def input_transform(size,scale):
  """Build the low-resolution input pipeline.

  The image is bicubically resized with an int target of ``size // scale``
  (torchvision matches the shorter edge to an int target), then converted to a tensor.
  """
  lr_size=size//scale
  steps=[transforms.Resize(lr_size,Image.BICUBIC),
         transforms.ToTensor()]
  return transforms.Compose(steps)
def target_transform():
  """Return the high-resolution target transform: a plain ToTensor()."""
  to_tensor=transforms.ToTensor()
  return to_tensor

def load_img(path):
  """Load an image and split it into YCbCr bands.

  Args:
    path: image file path.

  Returns:
    (y, u, v, img): the Y, Cb and Cr bands as single-band PIL images, plus the
    image in its original mode as a numpy array.
  """
  # Use a context manager so the file handle opened by Image.open is released
  # (the original left it open until garbage collection).
  with Image.open(path) as img:
    # convert() returns a new image, so the previous img.copy() was unnecessary.
    y,u,v=img.convert('YCbCr').split()
    img_array=np.asarray(img)
  return y,u,v,img_array

# Custom torch Dataset over the HR images of one split (train/val/test); LR inputs are derived on the fly.
class SRDataset(data.Dataset):
  """Super-resolution dataset over the HR images in ``<root_dir>/<fetch>/HR``.

  Each item loads one HR image and takes its Y band as the target. A copy of
  that Y band is passed through ``input_transform`` to form the LR input.
  For the 'val' split (any case), the chroma bands and the original image
  array are also returned for visualisation.
  """

  def __init__(self,root_dir,input_transform=None,target_transform=None,
               fetch="train"):
    self.root_dir=root_dir
    self.input_transform=input_transform
    self.target_transform=target_transform
    self.fetch=fetch
    # Sort once so the index -> file mapping is stable. Previously the list was
    # re-sorted on every __getitem__ call.
    self.image_file_names=sorted(os.listdir(os.path.join(self.root_dir,
                                                         self.fetch.lower(),"HR")))

  def __len__(self):
    return len(self.image_file_names)

  def __getitem__(self,idx):
    """Return ``{'LR', 'HR'}``, or ``(sample, u, v, HR_array)`` for the 'val' split."""
    if torch.is_tensor(idx):
      idx=idx.tolist()

    HR,u,v,HRimg=load_img(os.path.join(self.root_dir,self.fetch.lower(),
                                       "HR",self.image_file_names[idx]))
    LR=HR.copy()
    if self.input_transform:
      LR=self.input_transform(LR)
      u=self.input_transform(u)
      v=self.input_transform(v)
    if self.target_transform:
      HR=self.target_transform(HR)
    sample={'LR':LR,'HR':HR}

    # Compare case-insensitively, consistent with the directory lookup above.
    if self.fetch.lower()=='val':
      return sample,u,v,HRimg
    return sample

In [5]:
# Dataset root (expects train/val/test sub-folders and a ckpt/ folder).
# Override with the SR_DATASET_DIR environment variable instead of editing the notebook.
path=os.environ.get('SR_DATASET_DIR','/home/htic/SRDataset')

In [6]:
# Hyper-parameters handed to wandb.init(config=...).
# NOTE(review): `momentum` is not read by any code shown here (make() builds Adam
# with lr only) — confirm whether it is still needed.
config = {
    "epochs": 1000,
    "momentum": 0.9,
    "batch_size": 2,
    "learning_rate": 0.0001,
    "gradient_clip": 1.0,
    "dataset": "SRdata",
    "architecture": "SRHW",
}

In [7]:
#model pipeline
def model_pipeline(config=config):
  """Run one wandb-tracked experiment and return the trained model.

  Builds data, model, loss and optimizer via make(). It then trains with a fixed
  early-stopping patience of 100, the LR schedule disabled, and weights resumed
  from the 'checkpoint.pt' file in the ckpt folder.
  """
  with wandb.init(project="SRHW",config=config):
    # Use wandb's config object from here on (attribute access, e.g. config.batch_size).
    config=wandb.config
    built=make(config)
    model,train_loader,steps_per_epoch,val_loader,criterion,optimizer=built
    print("Model Created")
    patience=100
    resume_from='checkpoint.pt'
    return train(model,train_loader,steps_per_epoch,val_loader,criterion,optimizer,
                 patience,config,schedule=False,retrain=True,ckpt=resume_from)

In [8]:
#make
def make(config):
  """Create the loaders, model, L1 loss and Adam optimizer for one run.

  Training inputs use input_transform(128, 2) with ``config.batch_size``.
  Validation inputs use input_transform(256, 2), loaded by get_data as one batch.

  Returns:
    (model, train_loader, steps_per_epoch, val_loader, criterion, optimizer)
  """
  _,train_loader,steps_per_epoch=get_data(path,input_transform(128,2),target_transform(),
                                          'train',batch_size=config.batch_size)
  _,val_loader,_=get_data(path,input_transform(256,2),target_transform(),'val')

  model=SRHW().to(device)
  criterion=nn.L1Loss()
  optimizer=torch.optim.Adam(model.parameters(),lr=config.learning_rate)

  return model,train_loader,steps_per_epoch,val_loader,criterion,optimizer

In [9]:
# get_data: build an SRDataset for one split and wrap it in a DataLoader
def get_data(dir,input_transform,target_transform,fetch="train",
             batch_size=2,shuffle=True,num_workers=0):
  """Build the SRDataset for one split and wrap it in a DataLoader.

  Args:
    dir: dataset root containing ``<fetch>/HR``.
    input_transform, target_transform: transforms forwarded to SRDataset.
    fetch: split name. 'val' (any case) is loaded as a single batch holding the whole split.
    batch_size, shuffle, num_workers: DataLoader settings; batch_size is overridden for 'val'.

  Returns:
    (dataset, loader, steps_per_epoch). steps_per_epoch is the number of batches
    the loader yields per epoch, including a final partial batch.
  """
  dataset=SRDataset(root_dir=dir,input_transform=input_transform,
                    target_transform=target_transform,fetch=fetch)
  # Case-insensitive, matching how SRDataset resolves the split directory.
  if fetch.lower()=='val':
    batch_size=len(dataset)
  loader=data.DataLoader(dataset,batch_size=batch_size,shuffle=shuffle,
                         num_workers=num_workers)
  # len(loader) counts the trailing partial batch (drop_last=False). The previous
  # len(dataset)//batch_size under-counted whenever the split size was not a multiple.
  steps_per_epoch=len(loader)
  return dataset,loader,steps_per_epoch

In [10]:
#model
class SRHW(nn.Module):
    """Lightweight single-channel super-resolution network.

    The network has four stages:
    - a 3x3 conv;
    - a residual branch of two depthwise-separable 1x5 convs;
    - two depthwise-separable 3x3 convs producing ``upscale**2`` channels;
    - a PixelShuffle that rearranges those channels into an ``upscale``-times larger image.

    All conv weights are re-initialised with nn.init.uniform_, and no conv has a bias.
    """

    def __init__(self,upscale=2):
        super().__init__()

        def init_uniform(conv):
            # Re-initialise immediately after construction, keeping the same RNG
            # consumption order as creating and initialising each layer in turn.
            nn.init.uniform_(conv.weight)
            return conv

        self.Conv1=init_uniform(nn.Conv2d(1,32,3,padding=(1,1),bias=False))
        self.DWConv1=init_uniform(nn.Conv2d(32,32,(1,5),padding=(0,2),groups=32,bias=False))
        self.PWConv1=init_uniform(nn.Conv2d(32,16,1,bias=False))
        self.DWConv2=init_uniform(nn.Conv2d(16,16,(1,5),padding=(0,2),groups=16,bias=False))
        self.PWConv2=init_uniform(nn.Conv2d(16,32,1,bias=False))
        self.DWConv3=init_uniform(nn.Conv2d(32,32,3,padding=(1,1),groups=32,bias=False))
        self.PWConv3=init_uniform(nn.Conv2d(32,16,1,bias=False))
        self.DWConv4=init_uniform(nn.Conv2d(16,16,3,padding=(1,1),groups=16,bias=False))
        self.PWConv4=init_uniform(nn.Conv2d(16,upscale**2,1,bias=False))
        self.PS=nn.PixelShuffle(upscale)
        self.relu=nn.ReLU()

    def forward(self,x):
        # The skip connection uses the pre-activation Conv1 output.
        feat=self.Conv1(x)
        branch=self.relu(feat)
        branch=self.relu(self.PWConv1(self.DWConv1(branch)))
        branch=self.PWConv2(self.DWConv2(branch))
        out=self.relu(feat+branch)
        out=self.relu(self.PWConv3(self.DWConv3(out)))
        out=self.PWConv4(self.DWConv4(out))
        return self.PS(out)

In [11]:
#train with early stopping
def train(model,train_loader,steps_per_epoch,val_loader,criterion,optimizer,patience,
          config,schedule=True,retrain=False,ckpt='checkpoint.pt'):
  """Train ``model`` for up to ``config.epochs`` epochs with validation-based early stopping.

  Each epoch runs two passes:
  - one pass over ``train_loader``, with one optimiser step per batch via train_batch;
  - one pass over ``val_loader`` via validate_batch, without gradients.

  The mean validation loss is fed to EarlyStopping, which saves the best weights
  to ``<path>/ckpt/checkpoint.pt``. Training stops once it sets ``early_stop``.
  The best checkpoint is reloaded before returning.

  Args:
    model: network to train (already on ``device``).
    train_loader: yields dicts with 'LR' and 'HR' tensors.
    steps_per_epoch: training batches per epoch (currently unused here).
    val_loader: yields ``(sample, u, v, HR_array)`` tuples (SRDataset with fetch='val').
    criterion: loss function.
    optimizer: optimiser over ``model.parameters()``.
    patience: EarlyStopping patience, in epochs.
    config: run config. ``epochs`` and ``learning_rate`` are read here, and the
      whole object is passed on to train_batch.
    schedule: if True, StepLR multiplies the LR by 1.1 every 50 epochs.
    retrain: if True, initialise the weights from ``<path>/ckpt/<ckpt>`` first.
    ckpt: checkpoint file name used when ``retrain`` is True.

  Returns:
    The model with the best (lowest validation loss) weights loaded.
  """
  best_ckpt=os.path.join(path,'ckpt','checkpoint.pt')
  early_stopping=EarlyStopping(patience=patience,verbose=True,path=best_ckpt)
  wandb.watch(model,criterion,log="all",log_freq=10)
  example_ct=0      # training examples seen; used as the wandb step
  batch_ct=0
  example_ct_val=0
  batch_ct_val=0
  if retrain:
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(os.path.join(path,'ckpt',ckpt),map_location=device))
  scheduler=None
  if schedule:
    print('scheduling learning rate......')
    scheduler=torch.optim.lr_scheduler.StepLR(optimizer,step_size=50,gamma=1.1)

  def current_lr():
    # LR reported to wandb: the scheduler's value if scheduling, else the configured constant.
    return scheduler.get_last_lr()[0] if scheduler is not None else config.learning_rate

  # Bounded by config.epochs. The previous `while(True)` never terminated,
  # leaving the final checkpoint reload unreachable.
  for epoch in range(config.epochs):
    model.train()
    for sample in train_loader:
      loss=train_batch(sample['LR'],sample['HR'],model,
                       optimizer,criterion,config)
      example_ct+=len(sample['LR'])
      batch_ct+=1
      if ((batch_ct+1)%25)==0:
        log(loss,example_ct,epoch,current_lr(),0,0,train=True)

    model.eval()
    valid_losses=[]
    # No autograd graph is needed for validation, and building one wastes memory.
    with torch.no_grad():
      for sample,u,v,HRnp in val_loader:
        v_loss,val_psnr,val_ssim=validate_batch(sample['LR'],sample['HR'],
                                                u,v,HRnp,model,criterion,epoch)
        valid_losses.append(v_loss.item())
        example_ct_val+=len(sample['LR'])
        batch_ct_val+=1
        if (batch_ct_val+1)%8==0:
          log(v_loss,example_ct_val,epoch,current_lr(),
              val_psnr,val_ssim,train=False)

    # Step the (epoch-based) StepLR once per epoch only. Previously it was also
    # stepped after every training batch.
    if scheduler is not None:
      scheduler.step()
    early_stopping(np.average(valid_losses),model)
    if early_stopping.early_stop:
      print("Early stopping.........")
      break

  model.load_state_dict(torch.load(best_ckpt,map_location=device))
  return model

def validate_batch(LR,HR,u,v,hnp,model,criterion,epoch):
  """Evaluate one validation batch.

  The first image of the batch is rebuilt as an RGB preview and passed to plot().

  Args:
    LR, HR: low-/high-resolution luma tensors. Assumed (B, 1, H, W) with values
      in [0, 1] from ToTensor — TODO confirm for new transforms.
    u, v: chroma tensors of the LR input; used only to colourise the preview.
    hnp: batch of original HR image arrays; only ``hnp[0]`` is used, for plotting.
    model, criterion: network and loss function.
    epoch: current epoch, forwarded to plot().

  Returns:
    A tuple ``(loss, psnr_db, ssim_value)``:
    - loss: the criterion tensor.
    - psnr_db: PSNR computed from one MSE pooled over the whole batch,
      not a per-image average.
    - ssim_value: SSIM, with the batch axis moved last and treated as channels.
  """
  LR,HR=LR.float().to(device),HR.float().to(device)
  SR=model(LR)
  loss=criterion(SR,HR)
  out_img_y=SR.to('cpu').squeeze(1).detach().numpy()
  out_HR_y=HR.to('cpu').squeeze(1).detach().numpy()
  snr=psnr(out_img_y,out_HR_y)
  y=convert(out_img_y)
  # Images are in [0, 1]; pass data_range explicitly. Without it, skimage derives
  # the range from the float dtype (-1..1), i.e. 2, which biases SSIM upward.
  stsim=ssim(out_HR_y.transpose(1,2,0),out_img_y.transpose(1,2,0),
             multichannel=True,data_range=1.0)
  u=convert(np.asarray(u))
  v=convert(np.asarray(v))
  out_img=Image.merge('YCbCr',[y,u,v]).convert('RGB')
  hrimg=Image.fromarray(np.asarray(hnp[0]))
  plot(epoch,hrimg,out_img)
  return loss,snr,stsim

def psnr(SR,HR,max_val=1.0):
  """Peak signal-to-noise ratio, in dB, between two equally-shaped arrays.

  Args:
    SR: reconstructed array.
    HR: reference array.
    max_val: peak signal value (1.0 for [0, 1] images, 255 for 8-bit).

  Returns:
    ``20*log10(max_val) - 10*log10(MSE)``. Returns ``math.inf`` when the arrays
    are identical (MSE == 0); previously this raised a math domain error.
  """
  diff=np.subtract(HR,SR)
  mse=float(np.mean(np.power(diff,2)))
  if mse==0:
    return math.inf
  # Written so that max_val=1.0 gives exactly -10*log10(mse), as before.
  return 20*math.log10(max_val)-10*math.log10(mse)

def convert(img,scale=2):
  """Turn a [0, 1]-scaled array into an 8-bit single-band PIL image.

  Args:
    img: numpy array (or array-like) in [0, 1]; values are scaled by 255 and clipped.
      - Rank 3: ``img[0]`` is returned as an 'L' image.
      - Rank 4: ``img[0, 0]`` is converted and bicubically enlarged by ``scale``.
    scale: enlargement factor for the rank-4 case (default 2, SRHW's default upscale).

  Raises:
    ValueError: for any other rank. Previously this hit an UnboundLocalError.
  """
  arr=np.clip(np.asarray(img)*255.0,0,255)
  if arr.ndim==3:
    return Image.fromarray(np.uint8(arr[0]),mode='L')
  if arr.ndim==4:
    channel=Image.fromarray(np.uint8(arr[0,0]))
    return channel.resize((channel.size[0]*scale,channel.size[1]*scale),Image.BICUBIC)
  raise ValueError(f"convert expects a 3-D or 4-D array, got shape {arr.shape}")

def train_batch(LR,HR,model,optimizer,criterion,config):
  """Run one optimisation step on an (LR, HR) batch and return the loss tensor.

  Gradients are clipped to a total norm of ``config.gradient_clip`` before the update.
  """
  inputs=LR.float().to(device)
  targets=HR.float().to(device)
  loss=criterion(model(inputs),targets)
  optimizer.zero_grad()
  loss.backward()
  nn.utils.clip_grad_norm_(model.parameters(),config.gradient_clip)
  optimizer.step()
  return loss


def plot(epoch,HR,SR):
  """Show HR and SR images side by side, only when ``epoch + 1`` is a multiple of 500."""
  if (epoch+1)%500!=0:
    return
  plt.figure(epoch+1)
  panels=(('HR',HR),('SR',SR))
  for position,(title,image) in enumerate(panels,start=1):
    plt.subplot(1,2,position,title=title)
    plt.imshow(image)
  plt.show()

In [12]:
# log: report training or validation metrics to wandb (validation also prints a summary line)
def log(loss,example_ct,epoch,lr,psnr,ssim,train=True):
  """Report metrics to wandb.

  - Training (``train=True``): logs epoch, loss and LR at ``step=example_ct``.
  - Validation: logs PSNR, SSIM, loss and LR with no explicit step, and prints
    a one-line summary.
  """
  loss=float(loss)
  if not train:
    wandb.log({"AVG_PSNR":psnr,"SSIM":ssim,"validation_loss":loss,"lr":lr})
    seen=str(example_ct).zfill(5)
    print(f"Val_Loss after {seen} examples: {loss:.3f} lr:{lr:.6f} PSNR:{psnr:.4f} SSIM:{ssim:.4f}")
    return
  wandb.log({"epoch_t":epoch,"training_loss":loss,"lr":lr},step=example_ct)


In [None]:
# Ask for explicit confirmation before launching the (long) training pipeline.
print('Do you want to start training?(Y/N)')
a=input()
start_training=(a.lower()=='y')
if start_training:
  model=model_pipeline(config)

In [14]:
# Test-set PSNR on the luma (Y) channel.
# LR and HR files are paired purely by sorted position. Each LR Y band is upscaled
# to its HR size and compared with the HR Y band.
# `a` blends the model output (a=1) with a plain bicubic upscale (a=0).
# With a=0, the printed numbers are the bicubic baseline, not the network.
snr=0
dset='test'
lrlist=sorted(os.listdir(os.path.join(path,dset,'LR')))
hrlist=sorted(os.listdir(os.path.join(path,dset,'HR')))
assert len(lrlist)==len(hrlist), f"LR/HR count mismatch: {len(lrlist)} vs {len(hrlist)}"

a=0
img_to_tensor=transforms.ToTensor()  # hoisted: no need to rebuild it per image
with torch.no_grad():
  model=SRHW().to(device)
  # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
  model.load_state_dict(torch.load(os.path.join(path,'ckpt','checkpoint.pt'),map_location=device))
  model.eval()
  for lr_name,hr_name in zip(lrlist,hrlist):
    # Context managers release the file handles opened by Image.open.
    with Image.open(os.path.join(path,dset,'LR',lr_name)) as im, \
         Image.open(os.path.join(path,dset,'HR',hr_name)) as im_hr:
      y,_,_=im.convert('YCbCr').split()
      im_hr_y,_,_=im_hr.convert('YCbCr').split()
    # Bicubic upscale of the LR luma to the HR size (the a=0 baseline).
    out_hr_y=y.resize(im_hr_y.size,Image.BICUBIC)
    inp=img_to_tensor(y).unsqueeze(0).to(device)
    print(inp.shape)
    out_y=np.asarray(model(inp).to('cpu'))
    out_hr_y=np.asarray(img_to_tensor(out_hr_y).unsqueeze(0))
    im_hr_y=np.asarray(img_to_tensor(im_hr_y).unsqueeze(0))
    out_y=out_y*a+out_hr_y*(1-a)
    met=psnr(out_y,im_hr_y)
    print(met)
    snr+=met

print('Avg PSNR',snr/len(hrlist) if hrlist else float('nan'))
print(len(hrlist))

torch.Size([1, 1, 702, 1020])
36.80428369824899
torch.Size([1, 1, 924, 1020])
26.236401234876666
torch.Size([1, 1, 678, 1020])
27.26198116579586
torch.Size([1, 1, 672, 1020])
32.65617174289202
torch.Size([1, 1, 1020, 804])
30.135707421558266
torch.Size([1, 1, 1020, 678])
35.272585863736026
torch.Size([1, 1, 678, 1020])
25.795685310083382
torch.Size([1, 1, 678, 1020])
31.368010780588868
torch.Size([1, 1, 762, 1020])
26.29437593050619
torch.Size([1, 1, 822, 1020])
30.399273266857758
torch.Size([1, 1, 678, 1020])
31.273144591776333
torch.Size([1, 1, 678, 1020])
29.48162238727616
torch.Size([1, 1, 684, 1020])
27.074382035992585
torch.Size([1, 1, 1020, 678])
29.83807717299049
torch.Size([1, 1, 678, 1020])
28.22764051427019
torch.Size([1, 1, 678, 1020])
33.82680203867727
torch.Size([1, 1, 678, 1020])
32.88919341468198
torch.Size([1, 1, 684, 1020])
31.4058701900517
torch.Size([1, 1, 768, 1020])
30.394602170747635
torch.Size([1, 1, 696, 1020])
38.48229196184463
torch.Size([1, 1, 384, 1020])
28

In [15]:
# Upscale one image 2x with the trained model and save the RGB result as image2.png.
# The LR input is made by bicubically halving the source image. The model
# upscales the luma band; the chroma bands are bicubically upscaled.
a=1  # blend weight: 1 = model output only, 0 = bicubic upscale only
img_to_tensor=transforms.ToTensor()
with torch.no_grad():
  model=SRHW().to(device)
  # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
  model.load_state_dict(torch.load(os.path.join(path,'ckpt','checkpoint.pt'),map_location=device))
  model.eval()
  # Context manager releases the file handle; resize() returns an in-memory copy.
  with Image.open('/home/htic/4k.jpg') as im_hr:
    imlr=im_hr.resize((im_hr.size[0]//2,im_hr.size[1]//2),Image.BICUBIC)
  y,u,v=imlr.convert('YCbCr').split()
  out_hr_y=y.resize((y.size[0]*2,y.size[1]*2),Image.BICUBIC)
  inp=img_to_tensor(y).unsqueeze(0).to(device)
  print(inp.shape)
  out_y=np.asarray(model(inp).to('cpu'))
  out_hr_y=np.asarray(img_to_tensor(out_hr_y).unsqueeze(0))
  out_y=out_y*a+out_hr_y*(1-a)
  out_y=np.clip(out_y*255.0,0,255)
  out_y=Image.fromarray(np.uint8(out_y[0,0]),mode='L')
  out_img_cb=u.resize(out_y.size,Image.BICUBIC)
  out_img_cr=v.resize(out_y.size,Image.BICUBIC)
  out_img=Image.merge('YCbCr',[out_y,out_img_cb,out_img_cr]).convert('RGB')
  out_img.save('image2.png')
  print('image saved')

# print('Avg PSNR',snr/len(hrlist))

torch.Size([1, 1, 1080, 1920])
image saved


In [6]:
# Re-save the best checkpoint with _use_new_zipfile_serialization=False (the legacy,
# non-zip format) — presumably so older PyTorch versions can read it; confirm the consumer.
ckpt_dir=os.path.join(path,'ckpt')
model=SRHW().to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load(os.path.join(ckpt_dir,'checkpoint.pt'),map_location=device))
torch.save(model.state_dict(),os.path.join(ckpt_dir,'checkpoint1_1.pt'),
           _use_new_zipfile_serialization=False)

In [7]:
# Inspect the first-layer (Conv1) kernels of the loaded model.
conv1_weight=model.Conv1.weight
print(conv1_weight)

Parameter containing:
tensor([[[[-3.7673e-02,  5.1877e-02, -2.2276e-02],
          [ 6.0926e-02, -8.2014e-01,  9.4718e-02],
          [-2.5077e-02,  7.6823e-01, -6.9616e-02]]],


        [[[ 2.3889e-01,  3.6425e-01,  5.0051e-02],
          [-1.8791e-02, -6.0000e-01,  2.3152e-02],
          [ 2.1107e-01,  2.6575e-01,  1.8312e-01]]],


        [[[ 9.1050e-02,  2.2793e-01,  5.0284e-02],
          [ 7.8107e-01,  1.0090e+00, -2.8041e-01],
          [-9.5097e-02,  7.7779e-03,  4.2008e-02]]],


        [[[ 1.7408e-01,  8.9632e-02, -1.2008e-01],
          [-7.8207e-01,  5.8046e-01,  2.0605e-01],
          [ 8.5006e-01,  8.7285e-02,  2.5080e-01]]],


        [[[-9.4964e-03, -3.8138e-01,  4.8101e-01],
          [-4.6833e-01,  3.1087e-01,  2.6158e-01],
          [ 3.7710e-02,  3.7781e-01,  1.8380e-02]]],


        [[[-4.6267e-01, -2.7277e-01,  4.9907e-02],
          [ 6.1029e-01,  7.8088e-01, -3.4254e-01],
          [-7.6313e-02, -1.7453e-01,  3.6948e-02]]],


        [[[ 1.3436e-01,  6.1180e-01,

In [1]:
# clear cache
#https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27