In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix


In [None]:
# Checkpoint that supplies both the preprocessing config and the pretrained weights.
MODEL_NAME = 'trpakov/vit-face-expression'
# Number of target classes for this fine-tuning task.
NUM_CLASSES = 3

feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME)

# Ask transformers to build the classification head with NUM_CLASSES outputs
# instead of swapping `model.classifier` by hand. This keeps
# `model.config.num_labels` / `id2label` in sync with the actual head, so the
# model can be saved and reloaded without a shape mismatch.
# `ignore_mismatched_sizes=True` discards the checkpoint's original head weights
# and initialises the new head from scratch.
model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,
)

# NOTE: no `transforms.Normalize` here on purpose. The training/validation
# loops convert each tensor back to a uint8 PIL image (`mul(255).byte()`) and
# hand it to `feature_extractor`, which applies the checkpoint's own
# preprocessing. Normalising to [-1, 1] beforehand would produce negative
# values that are corrupted by the uint8 cast, so tensors must stay in [0, 1].

# Augmenting pipeline used for the training split only.
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
])

# Deterministic pipeline shared by the validation and test splits.
val_test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [29]:
class _TransformedSubset(Dataset):
    """Apply a split-specific transform on top of a (transform-free) subset.

    `random_split` returns `Subset` objects that all point at the same
    underlying dataset, so assigning `subset.dataset.transform` changes the
    transform for every split at once. Wrapping each subset keeps the
    transforms independent.
    """

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        image, label = self.subset[index]
        return self.transform(image), label


# Fixed seed so the train/val/test membership is reproducible across runs.
SPLIT_SEED = 42

# Loaded without a transform; each split applies its own below.
full_dataset = datasets.ImageFolder(root='data2')

# 70% / 15% / remainder split; the test split absorbs the rounding leftovers.
train_size = int(0.7 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_subset, val_subset, test_subset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training split gets augmentation.
train_dataset = _TransformedSubset(train_subset, train_transforms)
val_dataset = _TransformedSubset(val_subset, val_test_transforms)
test_dataset = _TransformedSubset(test_subset, val_test_transforms)

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [31]:
# Sanity check: number of samples that the training loader iterates over.
n_train_samples = len(train_loader.dataset)
print(n_train_samples)

3748


In [None]:

# Pick the compute device and move the model there *before* building the
# optimizer, so the optimizer is constructed over the parameters that will
# actually be trained on that device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Loss and optimizer used by the training loop.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [35]:
def train(model, train_loader, optimizer, criterion, num_epochs=10):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Besides its arguments, this function reads the notebook-level names
    ``feature_extractor``, ``device``, ``val_loader`` and ``validate``.

    Args:
        model: Model whose forward pass returns an object with a ``.logits``
            attribute.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable taking ``(logits, labels)``.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        Tuple ``(train_losses, val_losses, train_accuracies, val_accuracies)``;
        each element is a list with one entry per epoch. Training loss is the
        mean of the per-batch losses.
    """
    train_losses, train_accuracies = [], []
    val_losses, val_accuracies = [], []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0

        for images, labels in train_loader:
            # Rebuild HWC uint8 PIL images from the CHW float tensors so the
            # feature extractor can apply its own preprocessing.
            # NOTE: the uint8 cast is only valid for tensors in [0, 1].
            images = [Image.fromarray(img.permute(1, 2, 0).mul(255).byte().numpy()) for img in images]
            inputs = feature_extractor(images=images, return_tensors='pt')
            pixel_values = inputs['pixel_values'].to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(pixel_values).logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

            # Predicted class = index of the largest logit.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Report the same per-batch mean loss that is recorded for plotting
        # (previously the printed value was divided by the sample count).
        epoch_loss = running_loss / num_batches
        epoch_acc = correct / total
        val_results = validate(model, val_loader, criterion)

        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_acc)
        # validate() returns [accuracy, precision, recall, f1, cm, val_loss].
        val_losses.append(val_results[5])
        val_accuracies.append(val_results[0])

        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Training Acc: {epoch_acc:.4f}, Val Acc: {val_results[0]:.4f}')

    return (train_losses, val_losses, train_accuracies, val_accuracies)


def validate(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` without updating its weights.

    Besides its arguments, this function reads the notebook-level names
    ``feature_extractor`` and ``device``.

    Args:
        model: Model whose forward pass returns an object with a ``.logits``
            attribute.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable taking ``(logits, labels)``.

    Returns:
        List ``[accuracy, precision, recall, f1, cm, val_loss]``. Precision,
        recall and F1 are weighted averages over classes, ``cm`` is the
        confusion matrix, and ``val_loss`` is the mean of the per-batch
        losses (0.0 if the loader yields no batches), so it is directly
        comparable with a per-batch mean training loss.
    """
    model.eval()
    loss_sum = 0.0
    num_batches = 0
    y_true, y_pred = [], []

    with torch.no_grad():
        for images, labels in val_loader:
            # Rebuild HWC uint8 PIL images from the CHW float tensors so the
            # feature extractor can apply its own preprocessing.
            # NOTE: the uint8 cast is only valid for tensors in [0, 1].
            images = [Image.fromarray(img.permute(1, 2, 0).mul(255).byte().numpy()) for img in images]
            inputs = feature_extractor(images=images, return_tensors='pt')
            pixel_values = inputs['pixel_values'].to(device)
            labels = labels.to(device)

            val_outputs = model(pixel_values).logits
            # Predicted class = index of the largest logit.
            val_predicted = val_outputs.argmax(dim=1)

            y_true.extend(labels.cpu().tolist())
            y_pred.extend(val_predicted.cpu().tolist())

            loss_sum += criterion(val_outputs, labels).item()
            num_batches += 1

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    cm = confusion_matrix(y_true, y_pred)

    # Average over batches instead of returning the raw sum, which grew with
    # the number of batches and was not comparable with the training loss.
    val_loss = loss_sum / num_batches if num_batches else 0.0

    return [accuracy, precision, recall, f1, cm, val_loss]

In [None]:
num_epochs = 5
train_losses, val_losses, train_accuracies, val_accuracies = train(
    model, train_loader, optimizer, criterion, num_epochs=num_epochs
)

# Learning curves side by side: loss on the left, accuracy on the right.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: training vs. validation loss per epoch.
loss_ax.plot(train_losses, label='Train Loss')
loss_ax.plot(val_losses, label='Validation Loss')
loss_ax.set_title('Loss vs. Epochs')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

# Right panel: training vs. validation accuracy per epoch.
acc_ax.plot(train_accuracies, label='Train Accuracy')
acc_ax.plot(val_accuracies, label='Validation Accuracy')
acc_ax.set_title('Accuracy vs. Epochs')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend()

fig.tight_layout()
plt.show()