In [117]:
import torch
import torchvision.transforms.v2 as T

import numpy as np
import cv2, PIL
from pandas import DataFrame
from tqdm import tqdm
from sklearn.model_selection import train_test_split

In [118]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cpu'

### Data Load

In [119]:
X_train = np.load('data/classification1/Xtrain_Classification1.npy')
y_train = np.load('data/classification1/ytrain_Classification1.npy')
X_train.shape

(6254, 2352)

### Data Reshape and Split

In [120]:
data_shape = (-1,28,28,3)
X_train = X_train.reshape(data_shape)/255.
X_train = torch.FloatTensor(np.transpose(X_train, (0, 3, 1, 2))).to(device)
X_train.shape

torch.Size([6254, 3, 28, 28])

In [121]:
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2)
X_train.shape, X_val.shape, y_train.shape, y_val.shape

(torch.Size([5003, 3, 28, 28]),
 torch.Size([1251, 3, 28, 28]),
 (5003,),
 (1251,))

### Data Augmentation

In [133]:
y_train_df = DataFrame(y_train)
y_val_df = DataFrame(y_val)

y_val_df.value_counts()

0.0    1092
1.0     159
Name: count, dtype: int64

In [123]:
norm_params=(0.5,0.5,0.5)
augmentations = T.Compose([
    T.ToPILImage(),
    T.ColorJitter(0.05,0.05,0.05,0.05),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomRotation(20),
    T.ToTensor(),
    T.Normalize(norm_params,norm_params),
])

def augment(X, y):
    aug_X = torch.empty((0,3,28,28))
    aug_y=[]
    for i, (image, label) in tqdm(enumerate(zip(X,y))):
        aug_X = torch.cat((aug_X, image.unsqueeze(0)), 0)
        aug_y.append(label)    
        if label == 1:
            aug_X = torch.cat((aug_X, image.unsqueeze(0)), 0)
            aug_y.append(label)
            for _ in range(4):
                aug_X = torch.cat((aug_X, augmentations(image).unsqueeze(0)), 0)
                aug_y.append(label)
    return aug_X, np.array(aug_y)



In [124]:
X_train, y_train = augment(X_train,y_train)
X_train.shape, y_train.shape

310it [00:00, 552.08it/s]

5003it [01:01, 81.91it/s] 


(torch.Size([8688, 3, 28, 28]), (8688,))