In [2]:
import glob
import os
# NOTE(review): star import; no name from `collections` is used in the visible cells.
from collections import *
from datetime import datetime
from zipfile import ZipFile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data.dataloader as DataLoader
import torch.utils.data.dataset as Dataset
import torchvision
import torchvision.models as visionmodels
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable, grad
from torchvision import datasets
from torchvision.io import ImageReadMode, read_image

In [3]:
# Mount Google Drive: the data, the style image and all outputs live under it.
from google.colab import drive
drive.mount('/content/gdrive')

# Performance switches: allow TF32 for matmuls and cuDNN convolutions on GPUs
# that support it, and let cuDNN autotune kernels for the (fixed) input sizes.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# Run configuration read by the cells below.
FLAGS = {
    'datadir': '/content/gdrive/MyDrive/Colab Notebooks/Style Transfer/data/cats',
    # alternative labels: '/content/gdrive/MyDrive/Colab Notebooks/Style Transfer/data/landscapes.csv'
    'labeldir': '/content/gdrive/MyDrive/Colab Notebooks/data/cats_small.csv',
    'batch_size': 4,
    'learning_rate': .1,
    'num_epochs': 2000,
    'im_channels': 3,
    'noise_scale_factor': 0.0,    # scale of the uniform noise added to the validation image
    'loss_scale_lambda': 1.,      # NOTE(review): not read by any cell shown here
    'im_size': 256,               # side length of the square training crops
    'val_image_path': '/content/gdrive/MyDrive/Colab Notebooks/Style Transfer/data/cats/cat-2083492_1280 - Copy (2).jpg',
    'style_image_path': '/content/gdrive/MyDrive/Colab Notebooks/Style Transfer/data/Paintings/man_with_hat.jpg',
    'val_size_mult': 2,           # validation crop side = im_size * val_size_mult
    'output_path': 'outputs/Texture_net_BN',   # relative to home_dir
    'model_path': 'models/Texture_net_BN',     # relative to home_dir
    'output_fname': 'test_output_',
    'model_fname': 'landscape_model',          # NOTE(review): training data are cats; confirm name
    'home_dir': '/content/gdrive/MyDrive/Colab Notebooks/Style Transfer',
}

Mounted at /content/gdrive


In [32]:
class ConvLayer(nn.Module):
  """Three conv -> BatchNorm -> LeakyReLU(0.2) stages (3x3, 3x3, 1x1).

  Maps `in_channels` to `out_channels`. Spatial size is preserved: the 3x3 convs
  reflect-pad by 1 and every conv has stride 1.
  """
  def __init__(self, in_channels, out_channels):
    super().__init__()
    # NOTE(review): `sample` (a 1x1 projection) is registered, so it appears in
    # state_dict() and parameters(), but forward() never uses it. Possibly a
    # residual shortcut that was never wired up. It is kept so that saved
    # checkpoints keep loading.
    self.sample = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)

    def stage(channels_in, kernel_size):
      # 3x3 convs reflect-pad to keep H and W; the 1x1 conv needs no padding.
      if kernel_size == 3:
        conv = nn.Conv2d(channels_in, out_channels, kernel_size=3,
                         stride=1, padding=1, padding_mode='reflect')
      else:
        conv = nn.Conv2d(channels_in, out_channels, kernel_size=1, stride=1)
      return [conv,
              nn.BatchNorm2d(out_channels, affine=True),
              nn.LeakyReLU(.2, inplace=True)]

    self.main = nn.Sequential(*stage(in_channels, 3),
                              *stage(out_channels, 3),
                              *stage(out_channels, 1))

  def forward(self, input):
    return self.main(input)

class Join(nn.Module):
  """Fuse a coarse branch with a finer one.

  forward(in1, in2): `in1` is upsampled 2x (nearest neighbour), both branches are
  batch-normalised, and the results are concatenated along channels. The output
  therefore has `in_channel_1 + in_channel_2` channels (also exposed as
  `self.out_channels`). `in1` must be exactly half the height/width of `in2`.
  """
  def __init__(self, in_channel_1, in_channel_2):
    super().__init__()
    # Plain int attribute: not part of state_dict, so checkpoints are unaffected.
    self.out_channels = in_channel_1 + in_channel_2
    # nn.Upsample(mode='nearest') replaces the deprecated nn.UpsamplingNearest2d.
    # Both are parameter-free, so the BatchNorm keeps its `up_branch.1.*` keys.
    self.up_branch = nn.Sequential(
                                   nn.Upsample(scale_factor=2, mode='nearest'),
                                   nn.BatchNorm2d(in_channel_1)
                                  )
    self.down_branch = nn.BatchNorm2d(in_channel_2, affine = True)

  def forward(self, in1, in2):
    return torch.cat((self.up_branch(in1), self.down_branch(in2)), dim=1)

class ConvJoinBlock(nn.Module):
  """One decoder step: a ConvLayer on each input, then a Join.

  forward(in1, in2): `in1` (the coarser tensor) goes through `conv_upper`, and
  `in2` through `conv_lower`. Join then upsamples the first result 2x and
  concatenates it with the second. Output channels = out_channels_1 + out_channels_2.
  """
  def __init__(self, in_channels_1, out_channels_1, in_channels_2, out_channels_2):
    super().__init__()
    # Construction order is kept (upper, lower, join) so parameter order and
    # random initialisation match earlier runs.
    self.conv_upper = ConvLayer(in_channels_1, out_channels_1)
    self.conv_lower = ConvLayer(in_channels_2, out_channels_2)
    self.join = Join(out_channels_1, out_channels_2)

  def forward(self, in1, in2):
    # Arguments are evaluated left to right, so conv_upper still runs first.
    return self.join(self.conv_upper(in1), self.conv_lower(in2))

class StyleTransferGenerator(nn.Module):
  """Multi-scale generator that fuses copies of the input from coarse to fine.

  forward(x) builds `num_layers` copies of x, each with half the height/width of
  the previous one (nearest neighbour). Starting from the coarsest copy, each
  ConvJoinBlock convolves the running output and the next finer copy, upsamples
  the former 2x and concatenates the two. `final_block` then maps the result back
  to `im_channels`.

  Args:
    num_layers: number of scales; must be >= 2.
    max_im_size: unused; kept for backward compatibility.
    im_channels: channels of the input and output images.
    feat_maps: channels contributed by each scale. The running output grows by
      `feat_maps` per block and reaches `num_layers * feat_maps` before final_block.

  Input height and width must be divisible by 2 ** (num_layers - 1). Otherwise the
  upsampled coarse branch cannot match the finer one inside Join.
  NOTE(review): final_block ends in LeakyReLU, so outputs are not bounded to [0, 1].
  """
  def __init__(self, num_layers = 6, max_im_size = 512, im_channels = 3, feat_maps = 8):
    super().__init__()
    if num_layers < 2:
      # With a single scale there are no blocks, and final_block would receive
      # im_channels instead of feat_maps channels.
      raise ValueError('StyleTransferGenerator needs num_layers >= 2, got {}'.format(num_layers))
    self.num_layers = num_layers
    self.layers = nn.ModuleList()
    branch_feat_maps = 0
    for i in range(num_layers - 1):
      branch_feat_maps += feat_maps
      # The first block's coarse input is the raw downsampled image; later blocks
      # receive the previous block's concatenated output.
      upper_in = im_channels if i == 0 else branch_feat_maps
      self.layers.append(ConvJoinBlock(upper_in, branch_feat_maps, im_channels, feat_maps))
    branch_feat_maps += feat_maps

    self.final_block = nn.Sequential(
                                    ConvLayer(branch_feat_maps, branch_feat_maps),
                                    nn.Conv2d(branch_feat_maps, im_channels, kernel_size = 1),
                                    nn.LeakyReLU(.2, inplace = True)
                                    )

  def forward(self, input):
    factor = 2 ** (self.num_layers - 1)
    if input.shape[-2] % factor or input.shape[-1] % factor:
      raise ValueError('input height/width must be divisible by {}, got {}'.format(
          factor, tuple(input.shape[-2:])))
    samples = self.get_downsamples(input)
    out = samples[-1]
    # Block k fuses the running output with the scale one step finer than it.
    for k, layer in enumerate(self.layers):
      out = layer(out, samples[self.num_layers - 2 - k])
    return self.final_block(out)

  def get_downsamples(self, input):
    """Return [input, input/2, input/4, ...] (num_layers tensors, nearest-neighbour)."""
    samples = [input]
    for _ in range(self.num_layers - 1):
      # Same call nn.Upsample(scale_factor=.5, mode='nearest') makes internally,
      # without constructing a module on every forward.
      samples.append(F.interpolate(samples[-1], scale_factor=.5, mode='nearest'))
    return samples

class Normalization(nn.Module):
  """Per-channel normalisation: (x - mean) / std, broadcast over (N, C, H, W).

  Args:
    mean, std: per-channel values (tensor or sequence of length C).
  """
  def __init__(self, mean, std):
    super().__init__()
    # Registered as buffers (instead of plain attributes) so that .to()/.cuda()/
    # .half() on this module or any parent moves them too. persistent=False keeps
    # them out of state_dict, so existing checkpoints load unchanged.
    self.register_buffer('mean', torch.as_tensor(mean).reshape(1, -1, 1, 1), persistent=False)
    self.register_buffer('std', torch.as_tensor(std).reshape(1, -1, 1, 1), persistent=False)

  def forward(self, x):
    return (x - self.mean) / self.std

class FeatureNetwork(nn.Module):
  """Run a VGG convolutional trunk and collect intermediate activations.

  forward(x) normalises x with the ImageNet channel statistics used by
  torchvision's pretrained models. It then returns `(content, styles)`: the
  activations after the layers listed in `content_layers` / `style_layers`, in
  increasing index order. An index present in both lists is reported as content
  only.

  Args:
    vgg_features: sequence of modules, e.g. list(vgg19(...).features).
  """
  def __init__(self, vgg_features):
    super().__init__()
    self.features = nn.ModuleList(vgg_features).eval()  # feature layers of a vgg19 network
    # For torchvision's vgg19 layout: 22 = relu4_2 (content);
    # 1, 8, 11, 20, 29 = relu1_1, relu2_2, relu3_1, relu4_1, relu5_1 (style).
    self.content_layers = [22]
    self.style_layers = [1,8,11,20,29]
    # Put the statistics on the same device as the VGG weights, rather than
    # relying on a global `device` that only exists after the main cell has run.
    ref = next(self.features.parameters(), None)
    stats_device = ref.device if ref is not None else torch.device('cpu')
    mean = torch.tensor([0.485, 0.456, 0.406], device=stats_device)
    std = torch.tensor([0.229, 0.224, 0.225], device=stats_device)
    self.norm = Normalization(mean, std)

  def forward(self, input):
    x = self.norm(input)
    content = []
    styles = []
    # Layers past the deepest requested index cannot affect the outputs, so stop there.
    last = max(self.content_layers + self.style_layers, default=-1)
    for i, module in enumerate(self.features):
      if i > last:
        break
      # NOTE(review): recorded activations are not cloned. This relies on the
      # module after each recorded index not modifying its input in place; for
      # vgg19 these indices are each followed by a conv or pooling layer.
      x = module(x)
      if i in self.content_layers:
        content.append(x)
      elif i in self.style_layers:
        styles.append(x)
    return content, styles


In [33]:


def plot_image(input_test,test_image,input_val, val_image, style_image, save_result: bool = True):
  """Draw a 3x2 grid of sample images and optionally save it as a PNG.

  Only the first item of each batch is drawn. Each argument is indexed as `[0]`
  and must give a (C, H, W) tensor with C = 3; values are clamped to [0, 1].

  Layout (row, col): (0,1) style image; (1,0) validation input; (1,1) stylised
  validation image; (2,0) test input; (2,1) stylised test image. Slot (0,0) is
  unused and hidden.

  Args:
    save_result: when True, the figure is written to
      <home_dir>/<output_path>/<output_fname>_<k>.png, where k is the number of
      PNG files already in that directory. The directory is created if needed.
      The working directory is not changed.
  """
  def first_as_hwc(t):
    # First batch item as H x W x C on CPU. Clamping matches what imshow does to
    # float RGB data anyway, but avoids its clipping warning.
    return t[0].detach().cpu().clamp(0, 1).permute(1, 2, 0)

  panels = [
    ((0, 1), style_image, 'Style Image'),
    ((1, 0), input_val, 'Validation Image'),
    ((1, 1), val_image, 'Stylized Validation Image'),
    ((2, 0), input_test, 'Test Image'),
    ((2, 1), test_image, 'Stylized Test Image'),
  ]
  fig, ax = plt.subplots(3, 2, figsize = (5,5), dpi = 200)
  for pos, img, title in panels:
    ax[pos].imshow(first_as_hwc(img))
    ax[pos].set_title(title)
  ax[0, 0].axis('off')

  if save_result:
    out_dir = os.path.join(FLAGS['home_dir'], FLAGS['output_path'])
    os.makedirs(out_dir, exist_ok=True)
    # glob.escape: the Drive path could contain glob metacharacters such as '['.
    num_imgs = len(glob.glob(os.path.join(glob.escape(out_dir), '*.png')))
    fname = FLAGS['output_fname'] + '_' + str(num_imgs) + '.png'
    fig.savefig(os.path.join(out_dir, fname))
  # Show regardless of save_result (previously nothing was shown when not
  # saving). Close afterwards so the training loop, which calls this
  # repeatedly, does not accumulate open figures.
  plt.show(block = False)
  plt.close(fig)
  
def save_model(model):
  """Save `model.state_dict()` under <home_dir>/<model_path>.

  The file is named <model_fname>_<k>.model, where k is the number of `.model`
  files already in that directory. The directory is created if missing, and the
  working directory is not changed.

  Returns:
    The path of the written file.
  """
  model_dir = os.path.join(FLAGS['home_dir'], FLAGS['model_path'])
  # Replaces the per-segment chdir/mkdir loop with bare excepts: that version
  # could leave the cwd inside the model directory if saving failed.
  os.makedirs(model_dir, exist_ok=True)
  num_models = len(glob.glob(os.path.join(glob.escape(model_dir), '*.model')))
  path = os.path.join(model_dir, FLAGS['model_fname'] + '_' + str(num_models) + '.model')
  torch.save(model.state_dict(), path)
  return path

In [34]:
# Unpack the cat images onto the Colab VM's local disk. The target is absolute,
# so the result no longer depends on the current working directory. That cwd is
# changed by other cells, and an interrupted run could leave it pointing inside
# Google Drive.
# NOTE(review): FLAGS['datadir'] points at the Drive copy, not at this folder, so
# the dataset never reads these extracted files. Confirm which one is intended.
zip_loc = '/content/gdrive/MyDrive/Colab Notebooks/data/cats_small.zip'
extract_dir = '/content/data/cats'
with ZipFile(zip_loc, 'r') as zf:
  zf.extractall(extract_dir)


  

In [35]:
class customDataset(Dataset.Dataset):
    '''Unlabelled image dataset.

    File names come from a CSV and are resolved against `img_dir`. Each item is a
    3-channel float image tensor scaled to [0, 1], optionally passed through
    `transform`.
    '''
    def __init__(self, annotations_file, img_dir, transform=None,target_transform = None):
        '''
        Args:
            annotations_file: CSV read with pandas; its second column (index 1)
                holds image file names relative to `img_dir`.
            img_dir: directory containing the images.
            transform: optional callable applied to each image tensor.
            target_transform: accepted for signature compatibility with labelled
                datasets; stored but never applied, since there are no labels.
        '''
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 1])
        # Force RGB decoding. Without a mode, grayscale files give 1 channel and
        # RGBA files give 4, which breaks batch collation and the 3-channel VGG input.
        image = read_image(img_path, mode=ImageReadMode.RGB) / 255
        if self.transform:
            image = self.transform(image)
        return image

def _initialize_dataset(data_path = None, label_path = None, data_transform = None,target_transform = None,
                        batch_size = None, num_workers = 2):
    '''Build a customDataset and a shuffling DataLoader over it.

    Args:
        data_path: image directory.
        label_path: CSV listing the image file names.
        data_transform: transform applied to each image.
        target_transform: forwarded to customDataset (unused there).
        batch_size: loader batch size; defaults to FLAGS['batch_size'].
        num_workers: DataLoader worker processes.
    Returns:
        (loader, dataset)
    '''
    if batch_size is None:
        batch_size = FLAGS['batch_size']
    dataset = customDataset(label_path, data_path, transform=data_transform,target_transform = target_transform)
    # Pinned memory only helps host-to-GPU copies; on CPU-only runtimes it just warns.
    training_set = torch.utils.data.DataLoader(
         dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
         pin_memory=torch.cuda.is_available())
    return training_set,dataset



In [36]:

def contentLoss(contents, target_content):
  """Sum of per-layer mean-squared errors between paired activations.

  Pairs are taken with zip, so extra entries in the longer list are ignored.
  Returns int 0 when there are no pairs.
  """
  return sum((F.mse_loss(c, t) for c, t in zip(contents, target_content)), 0)

def gram_matrix(A):
  """Normalised Gram matrix of a (N, C, h, w) activation tensor.

  Returns a (C, C) matrix when N == 1 and a (N, C, C) batch otherwise. Entries
  are channel inner products divided by C*h*w.
  """
  N, C, h, w = A.shape
  # reshape instead of view: view raises on non-contiguous inputs
  # (e.g. channels_last or transposed activations).
  if N == 1:
    feats = A.reshape(C, -1)
    G = torch.mm(feats, feats.t())
  else:
    feats = A.reshape(N, C, -1)
    G = torch.bmm(feats, feats.transpose(1, 2))
  return G.div(C * h * w)
def gramMSELoss(input,target):
  """Mean-squared error between gram_matrix(input) and a precomputed Gram `target`."""
  return F.mse_loss(gram_matrix(input), target)
def gramFrobLoss(input,target):
  """Frobenius-norm distance between gram_matrix(input) and `target`.

  The norm is taken over the last two (channel x channel) dimensions and summed
  over any leading batch dimension. Using (-2, -1) instead of (1, 2) also
  accepts an unbatched (C, C) Gram difference, which previously raised.
  """
  G = gram_matrix(input)
  return torch.linalg.norm(G - target, 'fro', (-2, -1)).sum()
def styleLoss(styles,target_styles,mode :str = 'fro'):
  """Sum Gram-matrix losses over paired style activations.

  Args:
    styles: activations of the generated image, one per style layer.
    target_styles: matching precomputed Gram matrices.
    mode: 'fro' (Frobenius norm, gramFrobLoss) or 'mse' (gramMSELoss).
  Returns:
    The summed loss, or int 0 when there are no pairs.
  Raises:
    ValueError: for any other mode. Previously an unknown mode silently returned
      0, which switched the style term off.
  """
  if mode not in ('fro', 'mse'):
    raise ValueError("styleLoss mode must be 'fro' or 'mse', got {!r}".format(mode))
  style_loss = 0
  for style, target in zip(styles, target_styles):
    if mode == 'fro':
      style_loss += gramFrobLoss(style, target)
    else:
      style_loss += gramMSELoss(style, target)
  return style_loss
  
def train(feature_net,       #network for extracting features
          generator,
          optimizer,
          scheduler,
          train_loader,
          val_image,         #content target image
          style_image,
          style_weight = 1e8,
          content_weight = 1e3,
          log_every = 10):
  """Train `generator` to restyle the images from `train_loader` like `style_image`.

  For each full batch, the generator output is compared with:
    - the VGG features of the input batch (content loss, times `content_weight`);
    - the Gram matrices of the style image ('mse' style loss, times `style_weight`).
  The optimizer steps once per batch and the scheduler once per epoch. Every
  `log_every` epochs the function prints the last batch's losses, stylises the
  validation image, plots via plot_image and checkpoints via save_model.

  Args:
    feature_net: module returning (content, styles) activation lists; its
      parameters are frozen here.
    generator: module being trained. All tensors are moved to its device.
    optimizer, scheduler: torch optimizer and LR scheduler for `generator`.
    train_loader: yields image batches. A batch whose size differs from
      FLAGS['batch_size'] ends the epoch, because the style targets are tiled to
      exactly that size.
    val_image: batched validation image; it is not modified.
    style_image: style reference, expected to have batch size 1 (its Gram
      matrices are tiled to the training batch size).
    style_weight, content_weight: loss multipliers.
    log_every: epoch interval for logging, plotting and checkpointing.
  Returns:
    The trained generator.
  Raises:
    ValueError: if the first epoch produced no full batch.
  """
  dev = next(generator.parameters()).device
  for p in feature_net.parameters():
    p.requires_grad = False

  style_input = style_image.to(dev)
  # Out-of-place on purpose. `.to()` returns the same tensor when it is already
  # on `dev`, so the previous in-place `+=`/`clip_` could modify the caller's val_image.
  val_input = val_image.to(dev)
  fixed_noise = torch.rand_like(val_input) * FLAGS['noise_scale_factor']
  val_input = (val_input + fixed_noise).clamp(0, 1)

  # The style targets are constant, so compute them once without building a graph.
  with torch.no_grad():
    _, style_feats = feature_net(style_input)
    target_styles = [gram_matrix(A).tile(FLAGS['batch_size'], 1, 1) for A in style_feats]

  loss = style_loss = content_loss = None
  batch = gen_output = None
  for epoch in range(FLAGS['num_epochs']):
    generator.train()
    for image in train_loader:
      if len(image) != FLAGS['batch_size']:
        break
      batch = image.to(dev).clamp(0, 1)
      optimizer.zero_grad()
      gen_output = generator(batch)
      contents, styles = feature_net(gen_output)
      # Content targets come from the raw input and need no gradient.
      with torch.no_grad():
        target_content, _ = feature_net(batch)
      style_loss = styleLoss(styles, target_styles, mode = 'mse') * style_weight
      content_loss = contentLoss(contents, target_content) * content_weight
      loss = style_loss + content_loss
      loss.backward()
      optimizer.step()
    if loss is None:
      # Previously this surfaced later as a NameError on `loss` at the first log step.
      raise ValueError('train_loader yielded no batch of size {}'.format(FLAGS['batch_size']))
    scheduler.step()
    if (epoch + 1) % log_every == 0:
      print('stats:{}, {},{}'.format(loss.detach(), style_loss.detach(), content_loss.detach()))
      generator.eval()
      with torch.no_grad():
        test_out = generator(val_input)
      plot_image(batch, gen_output, val_input, test_out, style_input)
      save_model(generator)
  return generator



In [37]:
def init_weights(m):
    """Initialise Conv2d layers in place; meant for `module.apply(init_weights)`.

    Weights get Xavier-uniform init with gain sqrt(2). Biases, when present, are
    set to 1. Other module types are left untouched.
    """
    if isinstance(m, nn.Conv2d):
        torch.nn.init.xavier_uniform_(m.weight, gain = np.sqrt(2))
        # Conv2d(bias=False) stores `bias` as None. Check the parameter itself:
        # the old `m.bias.data is not None` raised AttributeError in that case.
        if m.bias is not None:
            m.bias.data.fill_(1.)

In [None]:
if __name__ == '__main__':
  # Build the feature extractor, generator, optimiser, LR schedule and data, then train.
  os.chdir('/content/')
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  # NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of
  # `weights=...`; left unchanged because the installed version is not pinned here.
  vgg_net = visionmodels.vgg19(pretrained=True).to(device)
  vgg_net.eval()
  feature_net = FeatureNetwork(list(vgg_net.features)).to(device)
  generator = StyleTransferGenerator(num_layers = 6, im_channels = FLAGS['im_channels'], feat_maps = 8).to(device)
  generator.apply(init_weights)
  optimizer = optim.Adam(generator.parameters(), lr = FLAGS['learning_rate'], betas = (.5, .999))

  def lmbda(epoch):
    """MultiplicativeLR factor: x0.1 at epoch 1000 and every 200 epochs after it, else x1."""
    # 1000 % 200 == 0, so a single condition covers both the first decay and later ones.
    return .1 if epoch >= 1000 and epoch % 200 == 0 else 1
  scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda = lmbda)

  im_size = FLAGS['im_size']
  test_mult = FLAGS['val_size_mult']

  transform = transforms.Compose([transforms.Resize(im_size), transforms.CenterCrop((im_size, im_size))])
  val_transform = transforms.Compose([transforms.Resize(im_size*test_mult),
                                      transforms.CenterCrop((test_mult*im_size, test_mult*im_size))])

  train_loader, dataset = _initialize_dataset(data_path = FLAGS['datadir'], label_path = FLAGS['labeldir'],
                                              data_transform = transform)

  # Decode as RGB so grayscale/RGBA files still give 3 channels, matching the
  # style image's .convert('RGB').
  val_image = val_transform(read_image(FLAGS['val_image_path'], mode=ImageReadMode.RGB)/255).unsqueeze(0)
  style_image = transforms.ToTensor()(transform(Image.open(FLAGS['style_image_path']).convert('RGB'))).unsqueeze(0)

  output = train(feature_net,       #network for extracting features
                 generator,
                 optimizer,
                 scheduler,
                 train_loader,
                 val_image,         #content target image
                 style_image
                 )

Output hidden; open in https://colab.research.google.com to view.