In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data.dataloader import DataLoader

import copy
import argparse
import os
import logging
import sys
from tqdm import tqdm
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Module-level logger used by the training/evaluation helpers below.
# The stdout handler is attached only once: without the guard, every re-run of
# this cell would add another handler and duplicate every log line.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
if not logger.handlers:
    logger.addHandler(logging.StreamHandler(sys.stdout))


In [None]:
def test(model, test_loader, criterion):
    model.eval()
    running_loss=0
    running_corrects=0
    
    for inputs, labels in test_loader:
        outputs=model(inputs)
        loss=criterion(outputs, labels)
        _, preds = torch.max(outputs, 1)
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    total_loss = running_loss // len(test_loader)
    total_acc = running_corrects.double() // len(test_loader)
    
    logger.info(f"Testing Loss: {total_loss}")
    logger.info(f"Testing Accuracy: {total_acc}")

In [None]:
def train(model, train_loader, validation_loader, criterion, optimizer, epochs=5, patience=1):
    """Train ``model`` with per-epoch validation and early stopping.

    Each epoch makes one pass over ``train_loader`` with gradient updates and
    one pass over ``validation_loader`` with gradients disabled. Training stops
    once the validation loss has failed to improve for ``patience`` consecutive
    epochs. The weights from the epoch with the lowest validation loss are then
    loaded back into ``model`` before it is returned.

    Args:
        model: classifier to train (modified in place).
        train_loader: iterable of ``(inputs, labels)`` training batches.
        validation_loader: iterable of ``(inputs, labels)`` validation batches.
        criterion: loss function called as ``criterion(outputs, labels)``;
            assumed to return the batch mean (PyTorch default reduction).
        optimizer: optimizer over the parameters that should be updated.
        epochs: maximum number of epochs (default 5).
        patience: number of consecutive non-improving validation epochs
            tolerated before stopping (default 1).

    Returns:
        ``model``, holding the best-validation-loss weights.

    Raises:
        ValueError: if either loader yields no samples in an epoch.
    """
    loaders = {'train': train_loader, 'valid': validation_loader}
    best_loss = float('inf')
    best_weights = copy.deepcopy(model.state_dict())
    epochs_without_improvement = 0

    for epoch in range(epochs):
        for phase in ('train', 'valid'):
            is_train = phase == 'train'
            model.train(is_train)  # model.train(False) is equivalent to model.eval()
            running_loss = 0.0
            running_corrects = 0
            num_samples = 0

            # Build the autograd graph only while training; validation is a
            # single no-grad pass (no second evaluation loop needed).
            with torch.set_grad_enabled(is_train):
                for inputs, labels in loaders[phase]:
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    if is_train:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    _, preds = torch.max(outputs, 1)
                    batch_size = inputs.size(0)
                    running_loss += loss.item() * batch_size
                    running_corrects += torch.sum(preds == labels).item()
                    num_samples += batch_size

            if num_samples == 0:
                raise ValueError(f"{phase} loader yielded no samples")

            # Per-sample averages (the sums above are weighted by batch size).
            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects / num_samples
            logger.info(f"Epoch {epoch + 1}/{epochs} {phase}: loss={epoch_loss:.4f} acc={epoch_acc:.4f}")

            if not is_train:
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    best_weights = copy.deepcopy(model.state_dict())
                    epochs_without_improvement = 0
                else:
                    epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            logger.info(f"Early stopping after epoch {epoch + 1}")
            break

    # The last epoch may be worse than the best one (that is what triggers
    # early stopping), so hand back the best-validation weights.
    model.load_state_dict(best_weights)
    return model

In [None]:
def net(num_classes=5):
    """Build a ResNet-50 transfer-learning classifier.

    Loads ResNet-50 with ``pretrained=True``, freezes every existing parameter,
    and replaces the final ``fc`` layer with a trainable head
    ``Linear(in_features, 128) -> ReLU -> Linear(128, num_classes)``.
    Only the new head has ``requires_grad=True``, so the optimizer should be
    built from ``model.fc.parameters()``.

    Args:
        num_classes: number of output classes of the new head (default 5).

    Returns:
        The modified ResNet-50 model.
    """
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
    # of `weights=...`; kept as-is for compatibility with the installed version.
    model = models.resnet50(pretrained=True)

    # Freeze the backbone; the replacement head created below is trainable.
    for param in model.parameters():
        param.requires_grad = False

    # Take the head's input width from the backbone rather than hard-coding 2048.
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(in_features, 128),
        nn.ReLU(inplace=True),
        nn.Linear(128, num_classes),
    )
    return model

def create_data_loaders(data, batch_size, split_fractions=(0.5, 0.25, 0.25), seed=42, num_workers=4):
    """Split an ImageFolder directory into train/test/validation DataLoaders.

    Args:
        data: root directory in ``ImageFolder`` layout (one sub-directory per class).
        batch_size: batch size for all three loaders.
        split_fractions: ``(train, test, validation)`` fractions; must sum to 1.
            The test and validation sizes are rounded down and the train split
            receives the remainder.
        seed: seed for the split permutation, so the same images land in the
            same split on every run and the test split stays held out.
        num_workers: worker processes per DataLoader.

    Returns:
        ``(train_data_loader, test_data_loader, validation_data_loader)``.

    Raises:
        ValueError: if ``split_fractions`` does not sum to 1.
    """
    if abs(sum(split_fractions) - 1.0) > 1e-6:
        raise ValueError(f"split_fractions must sum to 1, got {split_fractions}")
    _, test_fraction, valid_fraction = split_fractions

    # Channel statistics of ImageNet, which torchvision's pretrained weights expect.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
        ])

    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
        ])

    # A Subset uses its parent dataset's transform, so the evaluation splits get
    # their own (non-augmenting) view of the same folder. ImageFolder orders
    # samples deterministically, so an index refers to the same image in both views.
    train_view = ImageFolder(data, transform=train_transform)
    eval_view = ImageFolder(data, transform=test_transform)

    num_images = len(train_view)
    num_test = int(num_images * test_fraction)
    num_valid = int(num_images * valid_fraction)
    num_train = num_images - num_test - num_valid

    generator = torch.Generator().manual_seed(seed)
    order = torch.randperm(num_images, generator=generator).tolist()
    train_idx = order[:num_train]
    test_idx = order[num_train:num_train + num_test]
    valid_idx = order[num_train + num_test:]

    train_data = torch.utils.data.Subset(train_view, train_idx)
    test_data = torch.utils.data.Subset(eval_view, test_idx)
    validation_data = torch.utils.data.Subset(eval_view, valid_idx)

    train_data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # Evaluation results do not depend on order, so keep it fixed.
    test_data_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    validation_data_loader = DataLoader(validation_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_data_loader, test_data_loader, validation_data_loader

In [None]:
# Run configuration.
SEED = 42
batch_size = 2
learning_rate = 1e-4
DATA_DIR = 'train_data'
MODEL_DIR = 'TrainedModels'

# Seed torch's global RNG (head initialization, loader shuffling) for repeatable runs.
torch.manual_seed(SEED)

train_loader, test_loader, validation_loader = create_data_loaders(DATA_DIR, batch_size)
model = net()

criterion = nn.CrossEntropyLoss()
# net() freezes the backbone, so only the replacement head's parameters are optimized.
optimizer = optim.Adam(model.fc.parameters(), lr=learning_rate)

logger.info("Starting Model Training")
model = train(model, train_loader, validation_loader, criterion, optimizer)

# Evaluate on the held-out test split built above.
logger.info("Evaluating on the test split")
test(model, test_loader, criterion)

# torch.save does not create missing directories.
os.makedirs(MODEL_DIR, exist_ok=True)
model_path = os.path.join(MODEL_DIR, 'model.pth')
torch.save(model.state_dict(), model_path)
logger.info(f"Saved model weights to {model_path}")