In [1]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision.transforms import ToTensor, Compose, Grayscale, Resize, CenterCrop
import matplotlib.pyplot as plt
from PIL import Image
from parse import parse
from custom_loader import AgeDBDataset
import warnings
warnings.filterwarnings("ignore")

In [2]:
# Hyperparameters shared by the cells below.
num_of_class = 102      # number of output classes (size of the final linear layer)
hidden_unit = 1024      # width of the LSTM hidden state
learning_rate = 1e-04   # step size handed to the optimizer
batch_size = 64         # samples per mini-batch
num_layers = 5          # intended number of stacked LSTM layers
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of crashing
# at the first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Build the AgeDB dataset. Every image goes through the same pipeline:
# resize to 64x64, centre-crop to 64, collapse to one grayscale channel,
# then convert to a tensor.
dataset = AgeDBDataset(
    directory='AgeDB/',
    device=device,
    transform=Compose([
        Resize((64, 64)),
        CenterCrop(64),
        Grayscale(num_output_channels=1),
        ToTensor(),
    ]),
)

In [4]:
# Split the dataset and unpack the three objects returned by `get_loaders`
# (train / validation / test, in that order).
# NOTE(review): train_size + test_size already sum to 1.0 -- confirm in
# custom_loader how the validation split is sized.
train_set, validation_set, test_set = dataset.get_loaders(
    train_size=0.8, test_size=0.2, batch_size=batch_size
)

In [5]:
class AGEDBRnnModel(nn.Module):
    """Stacked LSTM followed by a linear classifier on the last time step.

    Args:
        input_size: number of features per time step.
        hidden_size: width of the LSTM hidden state.
        num_layers: number of stacked LSTM layers.
        num_classes: number of logits produced by the final linear layer.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class logits for a batch of sequences.

        Args:
            x: tensor of shape (batch_size, seq_length, input_size).

        Returns:
            Tensor of shape (batch_size, num_classes).
        """
        # Zero initial hidden and cell states, allocated with the same device
        # and dtype as the input so the module does not depend on a global
        # `device` variable and keeps working after `.to(...)` / `.double()`.
        batch = x.size(0)
        h0 = x.new_zeros(self.num_layers, batch, self.hidden_size)
        c0 = x.new_zeros(self.num_layers, batch, self.hidden_size)

        # out: (batch_size, seq_length, hidden_size); the final states are unused.
        out, _ = self.lstm(x, (h0, c0))

        # Classify from the representation of the last time step only.
        return self.fc(out[:, -1, :])

In [6]:
# Instantiate the classifier and move it to the selected device.
# input_size=64 matches the width of the 64x64 images, which the training and
# evaluation code feed in as sequences of 64 rows with 64 features each.
RNNmodel = AGEDBRnnModel(
    input_size=64,
    hidden_size=hidden_unit,
    # Use the hyperparameter instead of a duplicated literal so that changing
    # `num_layers` in the configuration cell actually takes effect.
    num_layers=num_layers,
    num_classes=num_of_class,
).to(device)

In [7]:
# Adam over all model parameters, paired with a cross-entropy loss on the
# logits returned by the model.
optimizer = torch.optim.Adam(params=RNNmodel.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [8]:
def train(model, optimizer, criterion, train_loader, num_epochs):
    """Train ``model`` on ``train_loader`` for ``num_epochs`` epochs.

    Each batch must be an ``(imgs, labels)`` pair in which ``labels['age']``
    holds the targets (presumably integer class indices, as a cross-entropy
    criterion requires -- confirm against the dataset). Images are reshaped to
    ``(-1, 64, 64)`` before the forward pass. The loss of the last batch of
    every epoch is printed.

    Args:
        model: the network to optimise; it is switched to training mode.
        optimizer: optimizer already bound to ``model``'s parameters.
        criterion: callable ``criterion(outputs, labels)`` returning the loss.
        train_loader: sized iterable of ``(imgs, labels)`` batches.
        num_epochs: number of full passes over ``train_loader``.
    """
    # Move batches to wherever the model lives instead of relying on a
    # notebook-global `device`; a parameter-free model leaves batches in place.
    first_param = next(model.parameters(), None)
    total_step = len(train_loader)

    # An earlier evaluation may have left the model in eval mode; make sure
    # train-only layers such as dropout are active.
    model.train()

    for epoch in range(num_epochs):
        for i, (imgs, labels) in enumerate(train_loader):
            imgs = imgs.reshape(-1, 64, 64)
            labels = torch.as_tensor(labels['age'])
            if first_param is not None:
                imgs = imgs.to(first_param.device)
                labels = labels.to(first_param.device)

            # Forward pass
            outputs = model(imgs)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Report once per epoch, after its final step.
            if (i + 1) % total_step == 0:
                print(f"EPOCH: {epoch+1}/{num_epochs}, STEP: {i+1}/{total_step}, LOSS: {loss.item()}")

In [9]:
# Fit the model on the training loader for 20 epochs.
train(
    model=RNNmodel,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_set,
    num_epochs=20,
)

EPOCH: 1/20, STEP: 194/194, LOSS: 4.0491719245910645
EPOCH: 2/20, STEP: 194/194, LOSS: 4.041388034820557
EPOCH: 3/20, STEP: 194/194, LOSS: 4.037904262542725
EPOCH: 4/20, STEP: 194/194, LOSS: 4.035407543182373
EPOCH: 5/20, STEP: 194/194, LOSS: 4.033392429351807
EPOCH: 6/20, STEP: 194/194, LOSS: 4.031642913818359
EPOCH: 7/20, STEP: 194/194, LOSS: 4.030107498168945
EPOCH: 8/20, STEP: 194/194, LOSS: 4.028800964355469
EPOCH: 9/20, STEP: 194/194, LOSS: 4.027691841125488
EPOCH: 10/20, STEP: 194/194, LOSS: 4.026721000671387
EPOCH: 11/20, STEP: 194/194, LOSS: 4.025838851928711
EPOCH: 12/20, STEP: 194/194, LOSS: 4.025030612945557
EPOCH: 13/20, STEP: 194/194, LOSS: 4.0242838859558105
EPOCH: 14/20, STEP: 194/194, LOSS: 4.023584842681885
EPOCH: 15/20, STEP: 194/194, LOSS: 4.022908687591553
EPOCH: 16/20, STEP: 194/194, LOSS: 4.022242546081543
EPOCH: 17/20, STEP: 194/194, LOSS: 4.021596431732178
EPOCH: 18/20, STEP: 194/194, LOSS: 4.02099084854126
EPOCH: 19/20, STEP: 194/194, LOSS: 4.020423889160156
E

In [10]:
def eval(model, test_loader):
    """Print the top-1 accuracy of ``model`` over ``test_loader``.

    Each batch must be an ``(imgs, labels)`` pair in which ``labels['age']``
    holds the targets. Images are reshaped to ``(-1, 64, 64)`` before the
    forward pass. The model is put in evaluation mode for the duration of the
    call and its previous mode is restored afterwards.

    Note: this function shadows the built-in ``eval``; the name is kept
    because other cells call it.

    Args:
        model: the network to evaluate.
        test_loader: iterable of ``(imgs, labels)`` batches.
    """
    # Move batches to wherever the model lives instead of relying on a
    # notebook-global `device`; a parameter-free model leaves batches in place.
    first_param = next(model.parameters(), None)

    correct = 0
    total = 0

    # Evaluate with inference behaviour (e.g. dropout disabled), then put the
    # model back in whatever mode the caller had it in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs = imgs.reshape(-1, 64, 64)
                labels = torch.as_tensor(labels['age'])
                if first_param is not None:
                    imgs = imgs.to(first_param.device)
                    labels = labels.to(first_param.device)
                outputs = model(imgs)

                # Predicted class = index of the largest logit.
                _, pred = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (pred == labels).sum().item()
    finally:
        model.train(was_training)

    # An empty loader has no defined accuracy; avoid dividing by zero.
    if total == 0:
        print("ACC: n/a (no samples)")
        return

    print(f"ACC: {(100*correct)/total}%")

In [11]:
# Accuracy of the trained model on the test loader.
eval(model=RNNmodel, test_loader=test_set)

ACC: 2.1921341070277243%


In [12]:
# Accuracy on the training loader, for comparison with the test figure.
eval(model=RNNmodel, test_loader=train_set)

ACC: 2.4339136041263703%
