Start by importing the required packages, then load the data.

In [1]:
import torch
import os
import csv
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from PIL import Image 
import torchvision.datasets as dset
import torchvision.transforms as T
import random
import numpy as np
import matplotlib.pyplot as plt
import statistics
import gc

In [2]:
USE_GPU = True
SEED = 0  # single seed for Python, NumPy and PyTorch so shuffles and random inits are reproducible

# Working floating-point precision: half (16-bit), not full float32.
dtype = torch.float16

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds the CUDA generators

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('using device:', device)

using device: cuda


In [3]:
# Data ingest
#
# Each visible sub-directory of DATA_ROOT is one class; an image's label is the index
# of its class directory name in `labels_list`. Class directories are searched
# recursively. Every image is loaded as a (3, H, W) uint8 tensor.
DATA_ROOT = "crop"
MIN_SIDE = 160                      # drop images smaller than this in either dimension
MAX_HEIGHT, MAX_WIDTH = 800, 1440   # drop larger images; keep in sync with max_height / max_width

# Sorted so label indices do not depend on the filesystem's listing order; hidden
# entries (e.g. ".ipynb_checkpoints") are not classes.
labels_list = sorted(
    name for name in os.listdir(DATA_ROOT)
    if not name.startswith(".") and os.path.isdir(os.path.join(DATA_ROOT, name))
)

dataset = []
for label, class_name in enumerate(labels_list):
    for dirpath, _, filenames in os.walk(os.path.join(DATA_ROOT, class_name)):
        for filename in filenames:
            if filename.startswith("."):
                continue  # hidden files such as .DS_Store are not images
            # `with` releases the file handle; convert("RGB") guarantees 3 channels,
            # which the 3-channel patch embedding below relies on.
            with Image.open(os.path.join(dirpath, filename)) as img:
                image = T.functional.pil_to_tensor(img.convert("RGB"))
            _, h, w = image.shape
            if MIN_SIDE <= h <= MAX_HEIGHT and MIN_SIDE <= w <= MAX_WIDTH:
                dataset.append({"image": image, "label": label})

# Lists have no .shuffle() method; random.shuffle permutes the list in place.
random.shuffle(dataset)

In [4]:
# How many images survived the size filter at ingest.
num_images = len(dataset)
print(num_images)

9208


In [5]:
# Distribution of image heights and widths; each image is a (C, H, W) tensor.
heights, widths = [], []
for sample in dataset:
    _, height, width = sample["image"].shape
    heights.append(height)
    widths.append(width)

# mean, median and most common value(s): heights first, then widths
for values in (heights, widths):
    print(statistics.mean(values))
    print(statistics.median(values))
    print(statistics.multimode(values))
print(max(heights))
print(max(widths))

337.8471980886186
294.0
[207, 166, 191]
696.9323414422241
665.0
[1199]
800
1440


In [6]:
# Heights ranked 500-509 from the top. Sorting into a new list leaves `heights`
# in dataset order, so re-running this cell (or the cells above) stays consistent.
heights_desc = sorted(heights, reverse=True)
print(heights_desc[500:510])

[657, 657, 657, 656, 656, 656, 656, 656, 656, 656]


# Images taller than 800 px or wider than 1440 px (or smaller than 160 px in either dimension) are discarded at ingest

In [7]:
# Padding target (height, width) in pixels; equal to the upper bounds of the ingest size filter.
max_height, max_width = 800, 1440

In [8]:
class Patchwork(nn.Module):
    """Split a (C, H, W) image into flattened patches and embed them ViT-style.

    Steps:

    1. Zero-pad the image (or crop it, if larger) to ``max_height`` x ``max_width``.
    2. Cut it into non-overlapping ``patch_size`` x ``patch_size`` patches in row-major order.
    3. Flatten each patch and project it linearly to ``linear_size = patch_dim // 2`` features.
    4. Prepend a learnable classification token.
    5. Add a fixed sinusoidal positional embedding.

    Args:
        patch_size: side length of a square patch in pixels.
        max_height: padded image height; must be a multiple of ``patch_size``.
        max_width: padded image width; must be a multiple of ``patch_size``.
        in_channels: number of channels ``forward`` expects.
        dtype: floating-point dtype of the projection, token and positional table.

    Raises:
        ValueError: if ``max_height`` or ``max_width`` is not a multiple of ``patch_size``.
    """

    def __init__(self, patch_size=16, max_height=800, max_width=1440, in_channels=3,
                 dtype=torch.float16):
        super().__init__()
        if max_height % patch_size or max_width % patch_size:
            # Otherwise the last patch row/column would be short and patches could not be stacked.
            raise ValueError(
                f"max_height ({max_height}) and max_width ({max_width}) must be "
                f"multiples of patch_size ({patch_size})"
            )
        self.patch_size = patch_size
        self.max_height = max_height
        self.max_width = max_width
        self.in_channels = in_channels
        self.num_patches = (max_height // patch_size) * (max_width // patch_size)
        self.patch_dim = (patch_size ** 2) * in_channels
        self.linear_size = self.patch_dim // 2
        self.linear_layer = nn.Linear(self.patch_dim, self.linear_size, dtype=dtype)
        # Created once so it is a registered, trainable parameter shared by every call.
        self.classification_token = nn.Parameter(torch.rand(self.linear_size).to(dtype))
        # Fixed table as a buffer: moves with .to(device) / .half(), is not trained, and
        # (non-persistent) is not written to the state_dict.
        self.register_buffer(
            "positional_embedding", self.get_pos_embed().to(dtype), persistent=False
        )

    def get_pos_embed(self):
        """Build the fixed sinusoidal positional-embedding table.

        Uses the Transformer formulation with ``d = linear_size``, the width of the
        vectors the table is added to. For position ``pos`` and ``k = 0, 1, ...``:
        ``pe[pos, 2k] = sin(pos / 10000**(2k/d))`` and
        ``pe[pos, 2k+1] = cos(pos / 10000**(2k/d))``.

        Returns:
            float64 tensor of shape (num_patches + 1, linear_size). Row 0 belongs
            to the classification token.
        """
        d_model = self.linear_size
        positions = np.arange(self.num_patches + 1, dtype=np.float64)[:, None]
        even_dims = np.arange(0, d_model, 2, dtype=np.float64)
        angles = positions / np.power(10000.0, even_dims / d_model)
        table = np.zeros((self.num_patches + 1, d_model), dtype=np.float64)
        table[:, 0::2] = np.sin(angles)
        # For odd d_model there is one fewer cos column than sin columns.
        table[:, 1::2] = np.cos(angles[:, : d_model // 2])
        return torch.from_numpy(table)

    def _pad_to_max(self, image):
        """Zero-pad a (C, H, W) tensor to (C, max_height, max_width).

        When the total padding is odd, the extra pixel goes on the top / left.
        Negative amounts (image larger than the target) crop instead of pad.
        """
        _, h, w = image.shape
        pad_h = self.max_height - h
        pad_w = self.max_width - w
        pad_bottom, pad_right = pad_h // 2, pad_w // 2
        pad_top, pad_left = pad_h - pad_bottom, pad_w - pad_right
        return nn.functional.pad(image, (pad_left, pad_right, pad_top, pad_bottom))

    def _extract_patches(self, image):
        """Cut a padded (C, max_height, max_width) tensor into (num_patches, patch_dim) rows.

        Patches follow row-major order over the patch grid. Each row is one patch
        flattened in (channel, row, column) order.
        """
        p = self.patch_size
        channels = image.shape[0]
        grid_h, grid_w = self.max_height // p, self.max_width // p
        return (
            image.reshape(channels, grid_h, p, grid_w, p)
            .permute(1, 3, 0, 2, 4)  # -> (grid_row, grid_col, channel, row_in_patch, col_in_patch)
            .reshape(grid_h * grid_w, channels * p * p)
        )

    def forward(self, image):
        """Embed one image.

        Args:
            image: tensor of shape (in_channels, H, W), any numeric dtype (e.g. the
                uint8 output of ``pil_to_tensor``).

        Returns:
            Tensor of shape (num_patches + 1, linear_size): the classification token
            followed by one embedding per patch, with positional embedding added.

        Raises:
            ValueError: if ``image`` is not 3-D with ``in_channels`` channels.
        """
        if image.dim() != 3 or image.shape[0] != self.in_channels:
            raise ValueError(
                f"expected a ({self.in_channels}, H, W) image, got shape {tuple(image.shape)}"
            )
        weight = self.linear_layer.weight
        image = image.to(device=weight.device, dtype=weight.dtype)
        patches = self.linear_layer(self._extract_patches(self._pad_to_max(image)))
        patches = torch.vstack((self.classification_token, patches))
        return patches + self.positional_embedding

In [9]:
# Embed every image once, up front.
# NOTE(review): under no_grad the projection in `patch.linear_layer` and the
# classification token stay at their random initial values. To learn them, call
# `patch` inside the training forward pass instead of precomputing here.
# NOTE(review): the result holds len(dataset) x (num_patches + 1) x linear_size values.
# With the defaults and the 9208 images printed above, that is tens of GB even in
# float16. Consider embedding per batch instead.
patch = Patchwork()
if not dataset:
    raise ValueError("dataset is empty; nothing to embed")
patched_dataset = None
# no_grad: otherwise every output keeps its autograd graph (including the linear
# layer's saved inputs) alive, and memory grows with every image.
with torch.no_grad():
    for i, sample in enumerate(dataset):
        embedded = patch(sample["image"])
        if patched_dataset is None:
            # Preallocate once; torch.stack on a list at the end would need memory
            # for both the list and the stacked copy at the same time.
            patched_dataset = embedded.new_empty((len(dataset), *embedded.shape))
        patched_dataset[i] = embedded
        if i % 100 == 0:
            print(i)


0
100
200


KeyboardInterrupt: 