In [None]:
import pandas as pd 
import torch 
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from dataset import LeafDataset
from torchvision import transforms
from train_validate import Train_Validate
import timm 

%load_ext autoreload
%autoreload 2


In [None]:
# --- Data locations ---------------------------------------------------------
# NOTE(review): train_images_dir is an absolute path on one specific machine;
# anyone else running this notebook has to edit it.
train_images_dir = '/home/faris.almalik/Desktop/plants_diseas/train_images'
csv_path = './train.csv'

# --- Loader settings --------------------------------------------------------
batch_size = 64
num_workers = 12
shuffle_train, shuffle_test = True, False

# --- Problem / reproducibility settings -------------------------------------
num_classes = 5
image_size = (224, 224)   # target size handed to the resize transforms
random_state = 42         # seed shared by the split and the datasets


# Explore Dataset

In [None]:
# Load the CSV at csv_path into a DataFrame; later cells read its `label` column.
data_df = pd.read_csv(csv_path)
# Last expression of the cell, so the first rows are displayed as a sanity check.
data_df.head()

In [None]:
# Label histogram for the full dataset.
# One unit-wide bin per class, centred on the integer label, so every bar sits
# on its tick. Tick positions follow num_classes instead of a hard-coded 5.
fig, ax = plt.subplots()
ax.hist(data_df.label.values, bins=np.arange(num_classes + 1) - 0.5)
ax.set_xticks(list(range(num_classes)))
ax.set_title('Classes Distribution')
ax.set_xlabel('Classes')
ax.set_ylabel('Class Count')
plt.show()

# Get Test/Train Splits

In [None]:
# Stratified split on `label` (10% held out), seeded with random_state.
# Each part gets a fresh 0..n-1 index so positional and label-based lookups agree.
train_df, test_df = (
    part.reset_index(drop=True)
    for part in train_test_split(
        data_df,
        test_size=0.1,
        stratify=data_df.label,
        random_state=random_state,
    )
)
# train_df.to_csv('train_new.csv', index = False)
# test_df.to_csv('test_new.csv', index = False)


In [None]:
# Side-by-side label histograms for the two splits, to check that the
# stratified split kept the class proportions.
# Bins are one per class, centred on the integer label; ticks follow num_classes.
class_ticks = list(range(num_classes))
class_bins = np.arange(num_classes + 1) - 0.5

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(8, 4), sharey=True, sharex=True)
for axis, split_df, title in zip(ax, (train_df, test_df), ('Training', 'Testing')):
    axis.hist(split_df.label.values, bins=class_bins)
    axis.set_xticks(class_ticks)
    axis.set_title(title)
# The y axis is shared, so a single label on the left panel is enough.
ax[0].set_ylabel('Class Count')
plt.show()


In [None]:
# Number of training images per class.
# Series.value_counts() is used because the top-level pd.value_counts() is
# deprecated (pandas >= 2.1) and scheduled for removal.
classes_counts_dict = train_df.label.value_counts().to_dict()
print(f'Classes Counts \n {classes_counts_dict}')

# value_counts() orders by frequency; re-key by class index 0..num_classes-1 so
# that position i always refers to class i. Raises KeyError if a class has no
# training samples.
classes_counts_sorted = {i: classes_counts_dict[i] for i in range(num_classes)}
print(f'Sorted Classes Counts \n {classes_counts_sorted}')

In [None]:
# Inverse-frequency class weights, normalised to sum to 1.
per_class_counts = np.array(list(classes_counts_sorted.values()))
sum_samples = per_class_counts.sum()

# weight_i is proportional to 1 / (count_i / total): rarer classes weigh more.
inverse_frequency = 1. / (per_class_counts / sum_samples)
class_weights = inverse_frequency / inverse_frequency.sum()
class_weights

# Datasets and Dataloaders

In [None]:
# Training pipeline: light augmentation (flips, near-full random crop resized
# to image_size, up to 2 degrees of rotation) followed by tensor conversion.
# ToPILImage comes first, so the dataset presumably yields arrays/tensors
# rather than PIL images -- TODO confirm against LeafDataset.
train_steps = [
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomResizedCrop(size=image_size, scale=(0.95, 1.0)),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomAffine(degrees=2),
    transforms.ToTensor(),
]
train_transform = transforms.Compose(train_steps)

# Evaluation pipeline: deterministic resize to image_size, no augmentation.
test_steps = [
    transforms.ToPILImage(),
    transforms.Resize(image_size),
    transforms.ToTensor(),
]
test_transform = transforms.Compose(test_steps)

# Both datasets read the same CSV and image folder; only the transform and the
# `mode` flag differ. The CSV location comes from the csv_path config constant
# instead of a second hard-coded './train.csv', so the two cannot drift apart.
# NOTE(review): LeafDataset presumably derives its own train/test split from
# `mode` and `random_state` -- confirm it matches the train_test_split used
# earlier for the class-weight computation.
shared_dataset_args = dict(
    csv_file=csv_path,
    root_dir=train_images_dir,
    random_state=random_state,
)

train_dataset = LeafDataset(
    transform=train_transform,
    mode='train',
    **shared_dataset_args,
)

test_dataset = LeafDataset(
    transform=test_transform,
    mode='test',
    **shared_dataset_args,
)

# Training loader: shuffled (per shuffle_train); the final incomplete batch is
# dropped so every optimisation step sees exactly batch_size samples.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=shuffle_train,
    num_workers=num_workers,
    drop_last=True,
)

# Evaluation loader: keep the final incomplete batch. With drop_last=True up to
# batch_size - 1 test samples would be silently excluded from every metric.
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=shuffle_test,
    num_workers=num_workers,
    drop_last=False,
)

In [None]:
# Preview one test sample. Item [0] is the image; its axes are reordered with
# (1, 2, 0) before plotting (presumably channels-first -> channels-last, since
# the transform ends in ToTensor -- TODO confirm).
sample_image = test_dataset[200][0]
plt.imshow(np.transpose(sample_image.numpy(), (1, 2, 0)))

# Train the model

In [None]:
# Define criterion, model and optimizer.
lr = 2e-5
epochs = 1

# Resolve the device once and fall back to CPU when CUDA is unavailable.
# Unconditional .cuda() calls raise on CPU-only hosts, even though the trainer
# below is already constructed with a cuda-or-cpu fallback.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Class-weighted cross entropy to counter the class imbalance; the weights are
# the normalised inverse class frequencies computed earlier.
# (Unweighted alternative: torch.nn.CrossEntropyLoss().)
criterion = torch.nn.CrossEntropyLoss(
    weight=torch.tensor(class_weights, dtype=torch.float, device=device)
)

# Pretrained ViT-Base/16 (ImageNet-21k weights) with a classification head sized
# from the num_classes config constant rather than a hard-coded 5.
model = timm.create_model(
    'vit_base_patch16_224_in21k',
    pretrained=True,
    num_classes=num_classes,
).to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# Bundle model, loaders, optimizer and loss into the Train_Validate helper.
# The device string falls back to 'cpu' when CUDA is not available.
model_object = Train_Validate(
    model=model,
    train_loader=train_dataloader,
    test_loader=test_dataloader,
    epochs=epochs,
    optimizer=optimizer,
    criterion=criterion,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

In [None]:
# Train the model. fit_model() returns three values, unpacked here as
# train_acc, train_loss and f1_train (presumably one entry per epoch --
# TODO confirm against Train_Validate).
train_acc, train_loss, f1_train = model_object.fit_model()

In [None]:
# Plot the training curves returned by fit_model().
# Assumes each series holds one value per epoch -- TODO confirm.
# Points are drawn at epoch numbers 1..n and the ticks cover 1..epochs.
# Previously the curves sat at x = 0..n-1 while the ticks were range(1, epochs),
# so they were shifted by one and there were no ticks at all for epochs == 1.
curves = [
    (train_acc, '--*', {'color': 'yellow', 'label': 'Train ACC', 'alpha': 0.7}),
    (train_loss, '--o', {'color': 'black', 'label': 'Train Loss'}),
    (f1_train, '-x', {'color': 'green', 'label': 'Train F1_Score', 'alpha': 0.2}),
]
for values, fmt, style in curves:
    values = np.atleast_1d(values)
    plt.plot(np.arange(1, len(values) + 1), values, fmt, **style)
plt.legend()
plt.grid()
plt.xticks(list(range(1, epochs + 1)))
plt.xlabel('Epochs')
plt.show()

In [None]:
# Run Train_Validate.evaluation() on the trained model (presumably scores it on
# the test_loader passed at construction -- TODO confirm).
model_object.evaluation()