In [1]:
import os
assert os.environ['COLAB_TPU_ADDR'], 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

In [2]:
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl


Collecting torch-xla==1.9
  Downloading https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl (149.9 MB)
[K     |████████████████████████████████| 149.9 MB 26 kB/s 
Installing collected packages: torch-xla
  Attempting uninstall: torch-xla
    Found existing installation: torch-xla 1.9
    Uninstalling torch-xla-1.9:
      Successfully uninstalled torch-xla-1.9
Successfully installed torch-xla-1.9.1


In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.models as visionmodels
import torchvision.transforms as transforms
import torch.utils.data.dataloader as DataLoader
import torch.utils.data.dataset as Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable,grad
from torchvision import datasets
from collections import *
import os
from torchvision.io import read_image
from datetime import datetime
import glob
from zipfile import ZipFile

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu




In [4]:
from google.colab import drive
drive.mount('/content/gdrive')

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
zip_loc = '/content/gdrive/MyDrive/Colab Notebooks/data/cats_small.zip'
with ZipFile(zip_loc, 'r') as zf: 
  zf.extractall('data/cats')


Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [5]:
FLAGS = {}
FLAGS['datadir'] = 'data/cats'
FLAGS['labeldir'] = '/content/gdrive/MyDrive/Colab Notebooks/data/cats_small.csv'
FLAGS['batch_size'] = 4
FLAGS['num_workers'] = 4
FLAGS['learning_rate'] = .0001
FLAGS['num_epochs'] = 200
FLAGS['num_cores'] = 8
FLAGS['noise_dim'] = 3
FLAGS['noise_scale_factor'] = .01
FLAGS['loss_scale_lambda'] = 2e6
FLAGS['im_size'] = 512
FLAGS['val_image_path'] = '/content/gdrive/MyDrive/Colab Notebooks/Style Transfer/data/cats/val_cat.jpg' 
FLAGS['style_image_path'] = '/content/gdrive/MyDrive/Colab Notebooks/Style Transfer/data/monet/07.jpg' 
FLAGS['val_size_mult'] = 2

In [6]:
from matplotlib.pyplot import imshow
from matplotlib import pyplot as plt
from IPython import display 

from google.colab.patches import cv2_imshow
import cv2
    
RESULT_IMG_PATH = '/tmp/test_result.png'

def plot_results(image,epoch):
    fig = plt.figure(figsize= (5,5), dpi = 200) 
    ax = fig.subplots()
    ax.imshow(image[0].permute(1,2,0))
    plt.savefig('/tmp/test_result_{}.png'.format(epoch), transparent=True)
    plt.show(block = False)
def display_results(epoch):
    img = cv2.imread('/tmp/test_result_{}.png'.format(epoch), cv2.IMREAD_UNCHANGED)
    cv2_imshow(img)

In [7]:
class ConvLayer(nn.Module): 
  def __init__(self,in_channels, out_channels): 
    super().__init__() 
    self.main = nn.Sequential(
                              nn.Conv2d(in_channels, out_channels, kernel_size = 3, 
                                        stride = 1,padding =1, padding_mode = 'reflect', bias = True), 
                              nn.BatchNorm2d(out_channels), 
                              nn.LeakyReLU(.2,inplace= False), 
                              nn.Conv2d(out_channels, out_channels, kernel_size = 3, 
                                        stride = 1,padding =1, padding_mode = 'reflect', bias = True), 
                              nn.BatchNorm2d(out_channels), 
                              nn.LeakyReLU(.2,inplace= False), 
                              nn.Conv2d(out_channels, out_channels, kernel_size = 1, 
                                        stride = 1, bias = True), 
                              nn.BatchNorm2d(out_channels), 
                              nn.LeakyReLU(.2,inplace= False)     
                              )
  def forward(self,input): 
    return self.main(input) 

class Join(nn.Module): 
  def __init__(self,in_channel_1, in_channel_2):
    super().__init__()
    out_channels = in_channel_1 + in_channel_2
    self.up_branch = nn.Sequential(
                                   nn.UpsamplingNearest2d(scale_factor= 2), 
                                   nn.BatchNorm2d(in_channel_1)     
                                  )
    self.down_branch = nn.BatchNorm2d(in_channel_2)
  def forward(self, in1, in2):
    out1 = self.up_branch(in1)
    out2 = self.down_branch(in2)
    return torch.cat((out1,out2),dim=1) 

class ConvJoinBlock(nn.Module): 
  def __init__(self, in_channels_1, out_channels_1, in_channels_2, out_channels_2): 
    super().__init__()
    self.conv_upper = ConvLayer(in_channels_1,out_channels_1)
    self.conv_lower = ConvLayer(in_channels_2,out_channels_2) 
    self.join = Join(out_channels_1, out_channels_2) 

  def forward(self,in1, in2): 
    out1 = self.conv_upper(in1)
    out2 = self.conv_lower(in2)
    out = self.join(out1,out2)
    return out

class StyleTransferGenerator(nn.Module): 
  def __init__(self, num_layers = 6, max_im_size = 512,im_channels = 3,noise_channels = 3, feat_maps = 8): 
    super().__init__() 
    self.num_layers = num_layers
    self.layers = nn.ModuleList()
    branch_feat_maps = 0
    for i in range(num_layers-1): 
      branch_feat_maps += feat_maps
      if i == 0: 
        self.layers.append(ConvJoinBlock(im_channels,branch_feat_maps,im_channels,feat_maps))
      else: 
        self.layers.append(ConvJoinBlock(branch_feat_maps,branch_feat_maps,im_channels, feat_maps)) 
      
    branch_feat_maps += feat_maps
      
    self.final_block = nn.Sequential(
                                    ConvLayer(branch_feat_maps , branch_feat_maps), 
                                    nn.Conv2d(branch_feat_maps, im_channels,kernel_size = 1, bias = True), 
                                    nn.BatchNorm2d(im_channels), 
                                    nn.LeakyReLU(.2,inplace = False)
                                    )
  

  def forward(self,input): 
    samples = self.get_downsamples(input)
    out = samples[-1]     
    for i in range(self.num_layers-1,0,-1):
      out = self.layers[self.num_layers - (i+1)](out, samples[i-1])
    return self.final_block(out)

  def get_downsamples(self,input): 
    downsample = nn.Upsample(scale_factor=.5, mode = 'bilinear') 
    samples = [input]
    for i in range(self.num_layers-1):  
      samples.append(downsample(samples[-1]))
    return samples

class Normalization(nn.Module): 
  def __init__(self,mean,std): 
    super().__init__() 
    self.mean = mean.view(1,-1,1,1)
    self.std = std.view(1,-1,1,1) 
  def forward(self,x):  
    return (x-self.mean)/self.std 

class FeatureNetwork(nn.Module): 
  def __init__(self,vgg_features): 
    super().__init__()
    self.features = vgg_features #feature layers of a vgg19_bn network
    self.content_layers = [32]
    self.style_layers = [2,9,26,23,30,42]
    self.mean = torch.tensor([0.485, 0.456, 0.406])
    self.std = torch.Tensor([0.229, 0.224, 0.225])
    self.norm = Normalization(self.mean,self.std)
  def forward(self,input): 
    input = self.norm(input)
    content = []
    styles = [] 
    for i,module in enumerate(self.features): 
      input = module(input) 
      if i in self.content_layers: 
        content.append(input)
      elif i in self.style_layers: 
        styles.append(input)
    return content, styles


In [8]:
class customDataset(Dataset.Dataset):
    def __init__(self, annotations_file, img_dir, transform=None,target_transform = None):
        '''
            we just want an unlabelled image dataset
        '''
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return (len(self.img_labels))

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 1])
        image = read_image(img_path)/255

        if self.transform:
            image = self.transform(image)
        return image

def _initialize_dataset(data_path = None, label_path = None, data_transform = None,target_transform = None):
    dataset = customDataset(label_path, data_path, transform=data_transform,target_transform = target_transform)
    # training_set = torch.utils.data.DataLoader(
    #      dataset, batch_size=FLAGS['batch_size'], shuffle=True, num_workers=2, pin_memory=True)
    return dataset




# Need to rewrite training loop for XMA package

In [9]:
def init_data():
  im_size = FLAGS['im_size']
  source_data_dir = FLAGS['datadir']
  annotation_file = FLAGS['labeldir']
  compose = transforms.Compose([ transforms.Resize(im_size), transforms.CenterCrop((im_size,im_size))])
  dataset = _initialize_dataset(data_path = source_data_dir, label_path = annotation_file,\
                              data_transform = compose)

  return dataset

In [10]:
def init_weights(m):
    if isinstance(m,nn.Conv2d):
        torch.nn.init.xavier_normal_(m.weight)
        if m.bias.data is not None: 
          m.bias.data.fill_(0.)

In [11]:
SERIAL_EXEC = xmp.MpSerialExecutor()
# Only instantiate model weights once in memory.
vgg_net = visionmodels.vgg19_bn(pretrained=True)
#feature_net = FeatureNetwork(list(vgg_net.features))
generator = StyleTransferGenerator(num_layers = 6,im_channels = 3,noise_channels = FLAGS['noise_dim'], feat_maps = 8)
generator.apply(init_weights)

WRAPPED_GENERATOR = xmp.MpModelWrapper(generator) 
#WRAPPED_F_NET = xmp.MpModelWrapper(feature_net)
WRAPPED_VGG = xmp.MpModelWrapper(vgg_net)

def train(rank): 
  torch.manual_seed(1) 
  data = SERIAL_EXEC.run(lambda: init_data()) 
  train_sampler = torch.utils.data.distributed.DistributedSampler(
      data, 
      num_replicas = xm.xrt_world_size(), 
      rank = xm.get_ordinal(), 
      shuffle = True 
  )

  train_loader = torch.utils.data.DataLoader(
    data, 
    batch_size =FLAGS['batch_size'], 
    sampler = train_sampler, 
    num_workers = FLAGS['num_workers'], 
    drop_last = False   
  )

  num_batches = len(train_loader)
  device = xm.xla_device() 
  G = WRAPPED_GENERATOR.to(device) 
  #f_net = WRAPPED_F_NET.to(device)
  vgg = WRAPPED_VGG.to(device)
  feature_net = FeatureNetwork(list(vgg.features)).to(device)
  f_net = xmp.MpModelWrapper(feature_net).to(device)
  '''
    optimizer and disable feature net gradients
  '''
  optimizer = optim.Adam(G.parameters(), lr = FLAGS['learning_rate'], betas = (0.0,.999)) 
  for p in f_net.parameters(): 
    p.requires_grad = False

  xm.master_print('mean/std device:'.format(f_net.mean.device, f_net.std.device))

  '''
    training parameters
  '''
  noise_dim = FLAGS['noise_dim']
  noise_sf = FLAGS['noise_scale_factor']
  lmbda = FLAGS['loss_scale_lambda']
  num_epochs = FLAGS['num_epochs']
  im_size = FLAGS['im_size']
  test_mult = FLAGS['val_size_mult']
  '''
    GRAM MATRIX LOSS 
  '''
  def gram_matrix(A): 
    N,C, h, w = A.shape
    if N == 1:
      A = A.view(C,-1)
      G = torch.mm(A,A.t())
    else: 
      A = A.view(N,C,-1)
      G = torch.bmm(A,A.transpose(1,2))
    return G.div(C*h*w) #returns CxC normalized gram matrix

  def gramMSELoss(input,target): 
    #assuming target is already a gram matrix 
    G = gram_matrix(input)
    return F.mse_loss(G,target,reduction = 'mean') 

  '''
    style image and necessary transforms
  '''  
  #transform = transforms.Compose([ transforms.Resize(im_size), transforms.CenterCrop((im_size,im_size))])
  transform = transforms.RandomResizedCrop(im_size, scale = (1,1))
  val_transform = transforms.Compose([transforms.Resize(test_mult*im_size),transforms.CenterCrop((test_mult*im_size,test_mult*im_size))])
  val_image = val_transform(read_image(FLAGS['val_image_path'])/255).unsqueeze(0)
  val_input = Variable(val_image, requires_grad = False).to(device)
  fixed_noise = torch.rand(val_input.shape[0],noise_dim, val_input.shape[2], val_input.shape[3],device = device)*noise_sf
  

  def train_step(optimizer, input,device): 
    optimizer.zero_grad() 

    '''
      train net
    '''
    style_image = transform(read_image(FLAGS['style_image_path'])/255).unsqueeze(0)
    style_input = Variable(style_image,requires_grad = False).to(device)
    _,target_styles = f_net(style_input)
    for i,A in enumerate(target_styles): 
      target_styles[i] = gram_matrix(A.detach()).tile(FLAGS['batch_size'],1,1)

    nz = Variable(torch.rand(FLAGS['batch_size'],noise_dim,im_size,im_size,device = device),requires_grad = False)*noise_sf
    input += nz
    optimizer.zero_grad() 
    '''
        forward pass
    '''
    gen_output = G(input)
    '''
      generate test style and content 
    '''
    contents,styles = f_net(gen_output) 
    '''
      generate actual content
    '''
    target_content, _  = f_net(input)
    style_loss = 0
    content_loss = 0
    for style,target in zip(styles,target_styles): 
      style_loss += gramMSELoss(style,target)
    for content, target in zip(contents, target_content): 
      content_loss += F.mse_loss(content,target,reduction = 'mean')
    
    loss = style_loss*lmbda + content_loss
    loss.backward()
    xm.optimizer_step(optimizer)

    return loss.detach()
  
  def train_loop(loader): 
    tracker = xm.RateTracker() 
    G.train() 
    f_net.eval() 

    for n_batch, image in enumerate(loader): 
      if len(image) != FLAGS['batch_size']: 
        break
      input = Variable(image).to(device)  

      loss = train_step(optimizer, input, device)

    return loss
  
  for epoch in range(1,num_epochs+1): 
    loss = train_loop(pl.MpDeviceLoader(train_loader, device))
    xm.master_print("Epoch {}/{}: Loss: {}".format(epoch, num_epochs, loss))
    if (epoch+1)%10 == 0:
      with torch.no_grad(): 
        G.eval()
        xm.do_on_ordinals(plot_results, (G(val_input+fixed_noise).detach(),epoch), (0,)) 
        #xm.do_on_ordinals(display_results, epoch, (0,))
    

In [12]:
def _mp_fn(rank,flags): 
  global FLAGS
  FLAGS = flags
  torch.set_default_tensor_type('torch.FloatTensor') 
  train(rank)

xmp.spawn(_mp_fn, args = (FLAGS,), nprocs = FLAGS['num_cores'], start_method= 'fork')

  cpuset_checked))


mean/std device:


  cpuset_checked))
  cpuset_checked))
  "See the documentation of nn.Upsample for details.".format(mode)
  "The default behavior for interpolate/upsample with float scale_factor changed "
  cpuset_checked))
  cpuset_checked))
  cpuset_checked))
  cpuset_checked))
  cpuset_checked))
  "See the documentation of nn.Upsample for details.".format(mode)
  "The default behavior for interpolate/upsample with float scale_factor changed "
  "See the documentation of nn.Upsample for details.".format(mode)
  "The default behavior for interpolate/upsample with float scale_factor changed "
  "See the documentation of nn.Upsample for details.".format(mode)
  "The default behavior for interpolate/upsample with float scale_factor changed "
  "See the documentation of nn.Upsample for details.".format(mode)
  "The default behavior for interpolate/upsample with float scale_factor changed "
  "See the documentation of nn.Upsample for details.".format(mode)
  "The default behavior for interpolate/upsample w

Epoch 1/200: Loss: 0.03383183106780052


Exception in device=TPU:6: Resource exhausted: From /job:tpu_worker/replica:0/task:0:
2 root error(s) found.
  (0) Resource exhausted: Ran out of memory in memory space hbm. Used 8.04G of 7.98G hbm. Exceeded hbm capacity by 60.28M.

Total hbm usage >= 8.06G:
    reserved         18.00M 
    program           7.89G 
    arguments       151.68M 

Output size 69.61M; shares 67.54M with arguments.

Program hbm requirement 7.89G:
    global            36.0K
    scoped           732.0K
    HLO temp          7.71G (47.3% utilization: Unpadded (3.64G) Padded (7.70G), 0.2% fragmentation (12.77M))
    overlays        185.91M

  Largest program allocations in hbm:

  1. Size: 512.00M
     Shape: f32[4,8,512,512]{1,0,3,2:T(4,128)}
     Unpadded size: 32.00M
     Extra memory due to padding: 480.00M (16.0x expansion)
     XLA label: %fusion.3106 = (f32[8]{0:T(256)}, f32[8]{0:T(256)}, f32[4,8,512,512]{1,0,3,2:T(4,128)}) fusion(f32[8]{0:T(256)} %p49.57, f32[8,3,3,3]{0,3,2,1:T(4,128)} %p50.58, bf16[4,

ProcessExitedException: ignored