In [296]:
import torch
import torchvision
import torchvision.transforms as transforms

import os
import matplotlib.pyplot as plt
import numpy as np

# path to the data

train_data_path = './train'
test_data_path = './test'


In [297]:
print(len(os.listdir('./train')))

100


In [298]:
# transform for mean and std

train_trans_ms = transforms.Compose([transforms.ToTensor()])


In [299]:
# applying transformation and selecting the data 

train_dataset_ms = torchvision.datasets.ImageFolder(root = train_data_path , transform= train_trans_ms )

In [300]:
# loading the dataset, need to do so in batches or else we run out of RAM

train_loader_ms = torch.utils.data.DataLoader(dataset = train_dataset_ms, batch_size = 32, shuffle=False)


In [301]:
# function for calculating std and mean

def get_mean_and_std(loader):
    mean = 0.
    std = 0.
    total_img_count = 0
    # looping thrue each batch
    for images, _ in loader:
        # number of images in batch
        images_count_in_batch = images.size(0)
        # resizeing the image tensor in the batch in order to reduce the dimensions of the tensor form 4 to 3
        images = images.view(images_count_in_batch, images.size(1), -1)
        mean += images.mean(2).sum(0)
        std += images.std(2).sum(0)
        total_img_count += images_count_in_batch

    mean /= total_img_count
    std /= total_img_count

# return a proxy mean and std , we cant get the real one because we cant load the whole data set, so we calculate the avrage for each batch and then the avrage for all the batches 
    return mean,std



In [302]:
# returns the mean and std

mean , std = get_mean_and_std(train_loader_ms)



tensor([0.4714, 0.4700, 0.4550])
tensor([0.2398, 0.2303, 0.2324])


In [303]:
train_trans = transforms.Compose([transforms.ToTensor(), transforms.RandomHorizontalFlip(),transforms.RandomRotation(10),transforms.Normalize(mean,std)])

test_trans = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean,std)])

In [304]:
train_dataset = torchvision.datasets.ImageFolder(root = train_data_path, transform=train_trans)
test_dataset = torchvision.datasets.ImageFolder(root = test_data_path, transform=test_trans)

In [305]:
def show_trans_img(dataset):
    loader = torch.utils.data.DataLoader(dataset,batch_size = 6,shuffle=True)
    batch = next(iter(loader))
    Images,lables = batch
    grid = torchvision.utils.make_grid(Images,nrow=3)
    plt.figure(figsize=(11,11))
    plt.imshow(np.transpose(grid,(1,2,0)))
    print('lables:', lables)



In [306]:
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=32,shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=32,shuffle=True)

In [307]:
def set_device():
    if torch.cuda.is_available():
        dev = 'cuda:0'
    else:
        dev = 'cpu'
    return torch.device(dev)


In [308]:
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
#chose a model, but set weights to None so you can train it yourself
resnet_18_model = models.resnet18(weights=None)

num_ftrs = resnet_18_model.fc.in_features
number_of_classes = len(os.listdir('./train'))
resnet_18_model.fc = nn.Linear(num_ftrs,number_of_classes)
device = set_device()
resnet_18_model = resnet_18_model.to(device)
loss_fn = nn.CrossEntropyLoss()
# lr 0.01 to 0.1 experminet whit it
# momenntum makes gradient desecnt faster
# weight_decay extra error to loss function , prevents overfiting
optimizer = optim.SGD(resnet_18_model.parameters(),lr=0.01,momentum=0.9,weight_decay=0.003)

In [309]:
def evaluate_model_on_test_set(model,test_loader):
    model.eval()
    predicted_correctly_on_epoch = 0
    total = 0
    device =set_device()

    with torch.no_grad():
        for data in test_loader:
            images , lables = data
            images = images.to(device)
            lables = lables.to(device)
            total += lables.size(0)

            outputs = model(images)

            _ , predicted = torch.max(outputs.data,1)

            predicted_correctly_on_epoch += (predicted == lables).sum().item()
    
    epoch_acc = 100.0 * predicted_correctly_on_epoch / total
    print('     -Test dataset. Got %d out of %d images correctly(%.3f%%)' % (predicted_correctly_on_epoch,total, epoch_acc))

In [310]:
def save_checkpoint(state,filename = 'model_checkpoint.pth.tar'):
    print('=> Saveing checkpoint')
    torch.save(state,filename,_use_new_zipfile_serialization=False)



In [311]:
def load_checkpoint(checkpoint):
    print('=> Loading checkpoint')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['state_dict'])
    epoch.load_state_dict(checkpoint['state_dict'])


In [314]:
def train_nn(model,train_loader,test_loader,criterion,optimizer,n_epoch):
    device = set_device()
    for epoch in range(n_epoch):
        print('Epoch number %d' % (epoch + 1))
        model.train()
        running_loss = 0.0
        running_correct = 0.0
        total = 0
        checkpoint = {'state_dict' : model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        save_checkpoint(checkpoint)
        for data in train_loader:
            images , lables = data
            images = images.to(device)
            lables = lables.to(device)
            total += lables.size(0)

            optimizer.zero_grad()

            outputs = model(images)

            _ , predicted = torch.max(outputs.data,1)

            loss = criterion(outputs,lables)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_correct += (lables == predicted).sum().item()

        epoch_loss = running_loss/len(train_loader)
        epoch_acc = 100.0 * running_correct / total

        print("         -Training dataset. Got %d out of %d images correctly(%.3f%%). Epoch loss: %.3f" % (running_correct,total,epoch_acc,epoch_loss))
        evaluate_model_on_test_set(model,test_loader)
        
    print("Finished")

In [315]:
train_nn(resnet_18_model,train_loader,test_loader,loss_fn,optimizer,50)

Epoch number 1
=> Saveing checkpoint
         -Training dataset. Got 1373 out of 13572 images correctly(10.116%). Epoch loss: 3.908
     -Test dataset. Got 84 out of 500 images correctly(16.800%)
Epoch number 2
=> Saveing checkpoint
         -Training dataset. Got 3224 out of 13572 images correctly(23.755%). Epoch loss: 2.995
     -Test dataset. Got 151 out of 500 images correctly(30.200%)
Epoch number 3
=> Saveing checkpoint
         -Training dataset. Got 4434 out of 13572 images correctly(32.670%). Epoch loss: 2.558
     -Test dataset. Got 174 out of 500 images correctly(34.800%)
Epoch number 4
=> Saveing checkpoint
         -Training dataset. Got 5420 out of 13572 images correctly(39.935%). Epoch loss: 2.247
     -Test dataset. Got 202 out of 500 images correctly(40.400%)
Epoch number 5
=> Saveing checkpoint
         -Training dataset. Got 6289 out of 13572 images correctly(46.338%). Epoch loss: 1.998
     -Test dataset. Got 168 out of 500 images correctly(33.600%)
Epoch number 6
=