In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import random

# Pick the compute device once: use the GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("device:", device)

device: cuda


In [13]:
# Number of original images per batch.
BATCH_SIZE = 32

# CIFAR-10 training split, downloaded to ./data if missing; images are
# converted to tensors by ToTensor().
cifar10_train = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# NOTE: despite its name, `trainset` is a DataLoader that yields shuffled
# (images, labels) batches of up to BATCH_SIZE items.
trainset = torch.utils.data.DataLoader(
    cifar10_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

Files already downloaded and verified


In [19]:
def get_color_distortion(s: float = 0.5):
    """
    Build the color-distortion transform described in the paper.

    The returned transform applies a color jitter with probability 0.8 and
    then converts the image to grayscale with probability 0.2.

    s: float, strength of the color distortion; for CIFAR 10 the paper uses 0.5
    returns: a transforms.Compose chaining the two random transforms
    """
    jitter_strength = 0.8 * s
    color_jitter = transforms.ColorJitter(
        brightness=jitter_strength,
        contrast=jitter_strength,
        saturation=jitter_strength,
        hue=0.2 * s,
    )
    return transforms.Compose([
        transforms.RandomApply([color_jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
    ])

In [27]:
# Augmentation pipeline applied independently to produce each view of an image.
# Input is a single image tensor (as produced by ToTensor()); output is an
# augmented image tensor of spatial size 32x32.
train_transform = transforms.Compose([
    # The transforms below are applied to a PIL image, so convert the tensor first.
    transforms.ToPILImage(),
    # transforms.GaussianBlur(23, sigma=(0.1, 2.0)),  # CIFAR 10 doesn't use gaussian blur
    # Crop between 8% and 100% of the image area, then resize back to 32x32.
    # The previous upper bound of 0.1 limited every crop to at most 10% of the
    # image, so the two views of an image rarely shared any content.
    transforms.RandomResizedCrop(size=32, scale=(0.08, 1.0), ratio=(0.75, 1.33)),
    transforms.RandomHorizontalFlip(p=0.5),
    get_color_distortion(),
    transforms.ToTensor(),
])

In [36]:
# Build the two-view batch: every image of a batch is augmented twice and the
# two views are stored next to each other, so rows 2k and 2k+1 of
# <total_tensor> are the two augmentations of image k.
for data, label in trainset:
    views = []
    # Iterate over the images actually present in the batch instead of
    # range(BATCH_SIZE): the DataLoader is created without drop_last, so the
    # final batch may hold fewer than BATCH_SIZE images and fixed-range
    # indexing would raise an IndexError there.
    for par_tensor in data:
        views.append(train_transform(par_tensor))  # first aug
        views.append(train_transform(par_tensor))  # second aug

    # Stack once at the end rather than growing the tensor with torch.cat per
    # image, which re-copied the accumulated batch on every iteration.
    total_tensor = torch.stack(views, dim=0)

    # TODO: training step, not implemented yet:
    # pass <total_tensor> into the model
    # yhat = net(total_tensor)
    # calculate loss
    # loss = compute_loss(yhat)
    ### if ind % 2 == 0: j = ind + 1
    # loss.backward()
    # optimizer.step()
    print(total_tensor.shape)
    break  # only inspect the first batch for now
        

torch.Size([64, 3, 32, 32])
