In [32]:
import json
import os
import torch
import torchvision
import datetime


from torchvision.models import resnet18
from torchvision.transforms import v2  
from torch.utils.data import DataLoader
from FaceDataset import FaceDataset



from models import *

In [33]:
device = (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))


In [34]:
DATASET_PATH = "../dataset/"
TRAIN_PATH = DATASET_PATH + "/train/"
VALIDATION_PATH = DATASET_PATH + "/validation/"
landmarks = json.load(open(DATASET_PATH + 'data.json'))

In [35]:

transforms = v2.Compose([
    v2.Resize(256),
    v2.CenterCrop(224),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [36]:
import torch

from torch.utils.data import Dataset
from torchvision.io import read_image

class FaceDataset(Dataset):

    def __init__(self, image_names: list[str], image_path: str, labels: dict[str: list], device=torch.device('cpu'), transforms=None):
        self.transforms = transforms
        self.image_names = image_names
        self.image_path = image_path
        self.labels = labels
        self.device = device


    def __getitem__(self, index):
        
        image = read_image(self.image_path + self.image_names[index]).to(dtype=torch.float)
        label = self.labels[self.image_names[index]]
        if not isinstance(label, torch.Tensor):
            label = torch.tensor(label)
        #label = label.to(device=self.device)
        label = label.reshape(1, -1).squeeze()
        if self.transforms:
            image = self.transforms(image / 255)

        return image, label
    
    def __len__(self):
        return len(self.image_names)

In [37]:
train_images = os.listdir(TRAIN_PATH)
validation_images = os.listdir(VALIDATION_PATH)

train_dataset =  FaceDataset(train_images, TRAIN_PATH, landmarks, device, transforms=transforms)
validation_dataset = FaceDataset(validation_images, VALIDATION_PATH, landmarks, device, transforms=transforms)

train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=8, shuffle=True)

In [38]:
train_dataset[0][1].shape

torch.Size([136])

In [39]:
def training_loop(n_epochs, optimizer, scheduler, model, loss_fn, train_loader, validation_loader, ):
    best_score = 0
    best_epoch = 0
    counter = 0
    
    for epoch in range(n_epochs):
        model.train()
        for data, label in train_loader:
            data = data.to(device=device)
            label = label.to(device=device)
            output = model(data)
            #print(output.shape, label.shape)
            loss = loss_fn(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()
        model.eval()
        summary_loss = 0
        with torch.no_grad():
            for i, (data, label) in enumerate(validation_loader):
                data = data.to(device=device)
                label = label.to(device=device)
                output = model(data)
                
                summary_loss += loss_fn(output, label)

        validation_score = summary_loss/i


        if best_score != 0 and best_score < validation_score:
            counter += 1
            if counter == 5:
                print(f"Early stop on epoch {epoch}")
                print(f"Weights are loaded from epoch {best_epoch}")
                model.load_state_dict(best_weights)
                break
        else:
            counter = 0
            best_epoch = epoch
            best_score = validation_score
            best_weights = model.state_dict()

        #if epoch == 1 or epoch % 5 == 0:

        print('{} Epoch {}, Training loss {}, Validation loss {}, lr {}'.format(
            datetime.datetime.now(),
            epoch,
            loss / len(train_loader),
            validation_score,
            scheduler.get_last_lr())
        )

In [40]:
class ModifiedResnet(torch.nn.Module):
    def __init__(self, weights):
        super().__init__()
        self.model = resnet18(weights=weights)
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, 68 * 2)

    def forward(self, x):
        out = self.model(x)
        #return out.reshape(-1, 2)
        return out

In [41]:
model = ModifiedResnet18().to(device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.SmoothL1Loss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=20, gamma=0.2)
training_loop(n_epochs=200, optimizer=optimizer, model=model, loss_fn=loss_fn, scheduler=scheduler,
              train_loader=train_dataloader, validation_loader=validation_dataloader)
torch.save(model.state_dict(), "resnet18_weights.pt")

2024-06-28 14:32:23.562110 Epoch 0, Training loss 0.020378483459353447, Validation loss 11.605183601379395, lr [0.001]
2024-06-28 14:34:04.167827 Epoch 1, Training loss 0.03936924785375595, Validation loss 11.619023323059082, lr [0.001]
2024-06-28 14:35:42.412578 Epoch 2, Training loss 0.01916358806192875, Validation loss 10.04627799987793, lr [0.001]
2024-06-28 14:37:22.632751 Epoch 3, Training loss 0.01661578193306923, Validation loss 8.338556289672852, lr [0.001]
2024-06-28 14:39:02.470662 Epoch 4, Training loss 0.014667359180748463, Validation loss 6.046509265899658, lr [0.001]
2024-06-28 14:40:46.945296 Epoch 5, Training loss 0.019666526466608047, Validation loss 6.2515153884887695, lr [0.001]
2024-06-28 14:42:30.681035 Epoch 6, Training loss 0.015752624720335007, Validation loss 5.361302375793457, lr [0.001]
2024-06-28 14:44:08.948573 Epoch 7, Training loss 0.01403119321912527, Validation loss 4.727712154388428, lr [0.001]
2024-06-28 14:45:43.202096 Epoch 8, Training loss 0.00927