In [33]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.multiprocessing as mp
import torch.nn as nn
import timm
import numpy as np
from PIL import Image
import dlib
import os

In [34]:
# Preprocessing applied to every image fed to the networks in this notebook:
# resize to 224x224, turn the PIL image into a tensor, then normalise each
# channel with the mean/std triplets below (the values commonly used for
# ImageNet-pretrained backbones).
image_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


In [35]:
# Query CUDA availability once and derive the target device from it.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(use_cuda)

True


In [36]:
from tqdm import tqdm

class AffectNetDataset(Dataset):
    """AffectNet images paired with their integer expression labels.

    Every ``<name>.jpg`` found in ``images_dir`` is matched with
    ``<name>_exp.npy`` in ``annotations_dir``. All labels are read once in the
    constructor; images are only opened when a sample is requested.
    """

    def __init__(self, images_dir, annotations_dir, transform=None, limit=None):
        """
        Args:
            images_dir (string): Directory with all the images.
            annotations_dir (string): Directory with all the annotations.
            transform (callable, optional): Optional transform to be applied on a sample.
            limit (int, optional): Limit the number of samples to load for debugging.

        Raises:
            Whatever ``np.load`` raises when an image has no matching
            ``<name>_exp.npy`` file (e.g. ``FileNotFoundError``).
        """
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform

        # os.path.splitext strips only the final extension, so names that
        # contain extra dots survive intact. Sorting makes the sample order,
        # and therefore the subset selected by `limit`, reproducible, because
        # os.listdir returns entries in arbitrary order.
        filenames = sorted(
            os.path.splitext(f)[0]
            for f in os.listdir(images_dir)
            if f.endswith('.jpg')
        )

        if limit is not None:
            filenames = filenames[:limit]

        self.filenames = filenames

        # Labels are small, so load them all up front.
        self.targets = []
        for f in tqdm(self.filenames, desc="Loading annotations"):
            exp_path = os.path.join(annotations_dir, f + '_exp.npy')
            self.targets.append(int(np.load(exp_path).item()))

    def __len__(self):
        """Number of samples kept after applying ``limit``."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(image, expression)`` for sample ``idx``.

        The image is decoded as 3-channel RGB so the 3-channel normalisation
        used in this notebook also works for grayscale or RGBA files, and the
        file handle is closed before returning.
        """
        img_name = os.path.join(self.images_dir, self.filenames[idx] + '.jpg')
        with Image.open(img_name) as img:
            # convert() forces the lazy decode while the file is still open.
            image = img.convert('RGB')

        expression = self.targets[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, expression


In [37]:
# Root of the AffectNet copy used by this notebook; each split directory holds
# an `images/` and an `annotations/` sub-folder. Edit this one line to relocate.
AFFECTNET_ROOT = '/work/jiewenh/openFace/DATA/AffectNet'

# Training split. limit=1000 keeps only the first 1000 samples (debug-sized run).
train_dataset_expr = AffectNetDataset(
    images_dir=os.path.join(AFFECTNET_ROOT, 'train_set', 'images'),
    annotations_dir=os.path.join(AFFECTNET_ROOT, 'train_set', 'annotations'),
    transform=image_transforms,
    limit=1000,
)
train_loader_expr = DataLoader(train_dataset_expr, batch_size=4, shuffle=True, num_workers=4)

# Validation split, used as the test set. Evaluation sums over all samples, so
# shuffling adds nothing; keep a fixed order like the gaze test loader does.
test_dataset_expr = AffectNetDataset(
    images_dir=os.path.join(AFFECTNET_ROOT, 'val_set', 'images'),
    annotations_dir=os.path.join(AFFECTNET_ROOT, 'val_set', 'annotations'),
    transform=image_transforms,
)
test_loader_expr = DataLoader(test_dataset_expr, batch_size=4, shuffle=False, num_workers=4)
print(len(train_dataset_expr), len(test_dataset_expr))

Loading annotations: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1633.19it/s]
Loading annotations: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3999/3999 [00:02<00:00, 1616.15it/s]

1000 3999





In [80]:
import glob

class GazeDataset(Dataset):
    """Images paired with per-image gaze annotation text files.

    Scans ``data_dir`` for sub-folders matching ``p*`` and pairs every
    ``<name>.jpg`` inside them with ``<name>_gaze.txt`` from the same folder.
    Images without an annotation file are skipped.
    """

    def __init__(self, data_dir, transform=None):
        """
        Args:
            data_dir (string): Directory with all the person folders (p00, p01, ...).
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.data_dir = data_dir
        self.transform = transform
        self.samples = self._load_samples(data_dir)

    def _load_samples(self, data_dir):
        """Return a sorted list of ``(image_path, gaze_path)`` pairs under ``data_dir``."""
        samples = []
        # Use the `data_dir` argument (the previous code read an undefined
        # `root_dir`). glob order is filesystem-dependent, so sort to keep
        # sample indices stable between runs.
        for person_dir in sorted(glob.glob(os.path.join(data_dir, 'p*'))):
            for image_path in sorted(glob.glob(os.path.join(person_dir, '*.jpg'))):
                gaze_file = f"{os.path.splitext(image_path)[0]}_gaze.txt"
                if os.path.exists(gaze_file):
                    samples.append((image_path, gaze_file))
        return samples

    def __len__(self):
        """Number of (image, gaze file) pairs found."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, gaze)`` for sample ``idx``.

        ``gaze`` is the whole annotation file parsed as float32 and flattened
        to 1-D (14 values per sample in the output recorded in this notebook).
        """
        image_path, gaze_path = self.samples[idx]
        with Image.open(image_path) as img:
            # Decode as RGB while the file is open so the 3-channel
            # normalisation also works for non-RGB files.
            image = img.convert('RGB')
        gaze = np.loadtxt(gaze_path, dtype=np.float32).flatten()

        if self.transform is not None:
            image = self.transform(image)

        return image, torch.from_numpy(gaze)


# Folder holding the per-person gaze sub-directories; adjust for your machine.
data_dir = '/work/jiewenh/openFace/DATA/mpii_data'  # Update this path

# Full gaze dataset plus a small single-process loader used for inspection.
gaze_dataset = GazeDataset(
    data_dir=data_dir,
    transform=image_transforms,
)
gaze_loader = DataLoader(
    gaze_dataset,
    batch_size=4,
    shuffle=True,
)

In [81]:
import numpy as np

# Sanity check: pull two batches from `gaze_loader` and print their shapes and raw gaze vectors.

# Load a few batches from the DataLoader
for i, (images, gazes) in enumerate(gaze_loader):
    print(f"Batch {i+1}")
    print("Images shape:", images.shape)  # [batch_size, C, H, W]
    print("Gazes shape:", gazes.shape)  # [batch_size, 14] in the recorded output: each gaze file is flattened to one vector
    # Print gaze values for each sample in the batch
    for j, gaze in enumerate(gazes):
        print(f"Sample {j+1} gaze:", gaze.numpy())
    if i == 1:  # Stop after the second batch
        break


Batch 1
Images shape: torch.Size([4, 3, 224, 224])
Gazes shape: torch.Size([4, 14])
Sample 1 gaze: [ 6.00e+02  3.77e+02  6.59e+02  3.74e+02  4.00e-01  6.70e-02 -9.14e-01
  5.26e+02  3.81e+02  4.56e+02  3.82e+02  2.68e-01  7.30e-02 -9.61e-01]
Sample 2 gaze: [ 7.90e+02  3.51e+02  8.47e+02  3.53e+02  3.80e-02  9.20e-02 -9.95e-01
  7.09e+02  3.51e+02  6.54e+02  3.49e+02 -9.80e-02  9.10e-02 -9.91e-01]
Sample 3 gaze: [ 6.85e+02  2.64e+02  7.52e+02  2.68e+02 -4.00e-02  3.46e-01 -9.37e-01
  6.13e+02  2.65e+02  5.44e+02  2.64e+02 -1.86e-01  3.43e-01 -9.20e-01]
Sample 4 gaze: [ 7.76e+02  3.68e+02  8.31e+02  3.67e+02 -8.70e-02  2.73e-01 -9.58e-01
  7.03e+02  3.71e+02  6.52e+02  3.68e+02 -2.08e-01  2.76e-01 -9.38e-01]
Batch 2
Images shape: torch.Size([4, 3, 224, 224])
Gazes shape: torch.Size([4, 14])
Sample 1 gaze: [ 7.35e+02  2.60e+02  8.10e+02  2.74e+02 -1.84e-01  4.19e-01 -8.89e-01
  6.51e+02  2.57e+02  5.80e+02  2.55e+02 -3.14e-01  3.93e-01 -8.64e-01]
Sample 2 gaze: [ 7.24e+02  2.40e+02  8.05e

In [82]:
from torch.utils.data.dataset import random_split

dataset_size = len(gaze_dataset)
train_size = int(dataset_size * 0.8)  # 80% for training
test_size = dataset_size - train_size  # remainder, so the two sizes always sum to the total

# Split the dataset. A dedicated, seeded generator makes the train/test
# membership identical on every run; without it the split changes on each run.
split_generator = torch.Generator().manual_seed(42)
train_dataset_gaze, test_dataset_gaze = random_split(
    gaze_dataset, [train_size, test_size], generator=split_generator
)

train_loader_gaze = DataLoader(train_dataset_gaze, batch_size=4, shuffle=True, num_workers=4)
test_loader_gaze = DataLoader(test_dataset_gaze, batch_size=4, shuffle=False, num_workers=4)

In [97]:
# Standalone backbone: EfficientNet-B0 with its classifier replaced by an
# identity, so the forward pass returns the feature vector directly.
model = timm.create_model('tf_efficientnet_b0_ns', pretrained=False)
model.classifier = torch.nn.Identity()

# map_location='cpu' lets the checkpoint load on CPU-only machines even if it
# was saved from a GPU; the model is moved to `device` afterwards.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load('./models/state_vggface2_enet0_new.pt', map_location='cpu')
model.load_state_dict(state_dict)
model = model.to(device)

In [98]:
# Dump the module tree of the standalone backbone to inspect its layers.
print(model)

EfficientNet(
  (conv_stem): Conv2dSame(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
  (bn1): BatchNormAct2d(
    32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
    (drop): Identity()
    (act): SiLU(inplace=True)
  )
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bn1): BatchNormAct2d(
          32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNormAct2d(
          16, eps=

In [109]:
class MLT(nn.Module):
    """Multi-task network: one shared timm backbone with two task heads.

    The backbone's classifier is replaced by an identity so it yields a
    feature vector. Each task then has its own hidden ``Linear`` + ReLU layer
    followed by an output layer: emotion logits and a gaze regression vector.
    """

    def __init__(self, base_model_name='tf_efficientnet_b0_ns', num_classes=8, gaze_dim=14):
        """
        Args:
            base_model_name (str): Name passed to ``timm.create_model``.
            num_classes (int): Number of emotion classes (width of the emotion logits).
            gaze_dim (int): Width of the gaze regression output. The default of 14
                matches the 14-value gaze vectors printed earlier in this notebook.
        """
        super().__init__()
        self.base_model = timm.create_model(base_model_name, pretrained=False)
        self.base_model.classifier = nn.Identity()

        # Width of the backbone's feature vector; both heads are sized from it.
        feature_dim = self.base_model.num_features

        self.relu = nn.ReLU()

        # Task-specific hidden layers on top of the shared features.
        self.fc_emotion = nn.Linear(feature_dim, feature_dim)
        self.fc_gaze = nn.Linear(feature_dim, feature_dim)

        # Task-specific output layers.
        self.emotion_classifier = nn.Linear(feature_dim, num_classes)
        self.gaze_regressor = nn.Linear(feature_dim, gaze_dim)

    def forward(self, x):
        """Return ``(emotion_output, gaze_output)`` for the input batch ``x``."""
        features = self.base_model(x)

        features_emotion = self.relu(self.fc_emotion(features))
        features_gaze = self.relu(self.fc_gaze(features))

        emotion_output = self.emotion_classifier(features_emotion)
        gaze_output = self.gaze_regressor(features_gaze)

        return emotion_output, gaze_output


In [110]:
# Build the multi-task model and initialise only its backbone from the checkpoint.
model = MLT()

# map_location='cpu' lets the checkpoint load on CPU-only machines even if it
# was saved from a GPU; the whole model is moved to `device` afterwards.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load('./models/state_vggface2_enet0_new.pt', map_location='cpu')
model.base_model.load_state_dict(state_dict)

model = model.to(device)

In [111]:
# Dump the MLT module tree (shared backbone plus the two task heads).
print(model)

MLT(
  (base_model): EfficientNet(
    (conv_stem): Conv2dSame(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn1): BatchNormAct2d(
      32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): SiLU(inplace=True)
    )
    (blocks): Sequential(
      (0): Sequential(
        (0): DepthwiseSeparableConv(
          (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn1): BatchNormAct2d(
            32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): SiLU(inplace=True)
          )
          (se): SqueezeExcite(
            (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (act1): SiLU(inplace=True)
            (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (gate): Sigmoid()
          )
          (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 

In [112]:
# Phase 1 setup: exclude every backbone weight from gradient updates so that
# only the task heads are trained.
for backbone_param in model.base_model.parameters():
    backbone_param.requires_grad = False

# Optional: freeze the gaze branch as well (left disabled).
# for param in model.fc_gaze.parameters():
#     param.requires_grad = False
# for param in model.gaze_regressor.parameters():
#     param.requires_grad = False

In [113]:
# Training hyper-parameters.
# NOTE(review): in the cells shown, only `lr` is read (by the first optimizer).
# The DataLoaders are built with batch_size=4, and `epochs` / `gamma` are not
# referenced. Confirm whether they should be wired in.
batch_size = 64
epochs = 10
lr = 0.001
gamma = 0.7
seed = 42

# Actually apply the seed so shuffling and weight updates are repeatable
# (torch.manual_seed also seeds the CUDA generators).
torch.manual_seed(seed)
np.random.seed(seed)

In [114]:
def label_smooth(target, n_classes: int, label_smoothing=0.1):
    """Convert integer class labels into label-smoothed one-hot rows.

    Args:
        target: 1-D integer tensor of class indices, one per sample.
        n_classes: Number of classes (width of each output row).
        label_smoothing: Probability mass spread uniformly over all classes.

    Returns:
        Tensor of shape ``(len(target), n_classes)`` on ``target``'s device.
        The labelled class gets ``1 - label_smoothing + label_smoothing / n_classes``
        and every other class gets ``label_smoothing / n_classes``.
    """
    # One row per sample with a single 1 in the labelled column.
    one_hot = torch.zeros((target.size(0), n_classes), device=target.device)
    one_hot.scatter_(1, target[:, None], 1)
    # Scale the one-hot mass down and add the uniform share to every class.
    return one_hot * (1 - label_smoothing) + label_smoothing / n_classes

def cross_entropy_loss_with_soft_target(pred, soft_target, weights=None):
    """Cross-entropy between logits and a soft (probability) target.

    Args:
        pred: Logits of shape ``(batch, n_classes)``.
        soft_target: Target distribution with the same shape as ``pred``.
        weights: Optional per-class weights broadcastable against ``pred``.
            ``None`` (default) means unweighted. This used to be read from an
            undefined global, which raised ``NameError`` on a fresh kernel.

    Returns:
        Scalar tensor: per-sample loss summed over classes, averaged over the batch.
    """
    per_class = -soft_target * torch.nn.functional.log_softmax(pred, -1)
    if weights is not None:
        per_class = weights * per_class
    return torch.mean(torch.sum(per_class, 1))


def cross_entropy_with_label_smoothing(pred, target, weights=None):
    """Label-smoothed cross-entropy for integer class targets.

    Smooths ``target`` with ``label_smooth`` (using ``pred.size(1)`` as the
    number of classes) and forwards to ``cross_entropy_loss_with_soft_target``.
    ``weights`` is passed through unchanged.
    """
    soft_target = label_smooth(target, pred.size(1))
    return cross_entropy_loss_with_soft_target(pred, soft_target, weights)


# NOTE(review): train() and test_expr() in this notebook call F.cross_entropy
# directly, so `criterion` is not used by the cells shown. Confirm the intent.
criterion = cross_entropy_with_label_smoothing

In [115]:
# Merged Training

In [116]:
import torch.nn.functional as F
import torch.optim as optim

# Adam over the parameters that still require gradients; frozen ones are left out.
optimizer = optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=lr,
)


def test_expr(model, device, test_loader):
    model.eval()  # Set the model to evaluation mode
    test_loss = 0
    correct = 0
    with torch.no_grad():  # No gradients needed for testing
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            
            outputs = model(data)
            emotion_output = outputs[0]  # Only use emotion_output
            
            # Calculate loss
            test_loss += F.cross_entropy(emotion_output, target, reduction='sum').item()
            
            # Calculate accuracy
            pred = emotion_output.argmax(dim=1, keepdim=True)  # Get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)

    # return test_loss, accuracy
    
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.0f}%)')

def train(model, device, train_loader_gaze, train_loader_expr, optimizer, epoch, alpha=0.5):
    """Run one epoch that alternates expression and gaze optimisation steps.

    Each iteration takes one expression batch (cross-entropy on the emotion
    output) and one gaze batch (MSE on the gaze output), each with its own
    backward pass and optimizer step. The loop runs for as many iterations as
    the longer loader; once the shorter loader is exhausted its step is skipped.

    Args:
        model: Module returning ``(emotion_output, gaze_output)``.
        device: Device the batches are moved to.
        train_loader_gaze: Loader of ``(image, gaze_target)`` batches.
        train_loader_expr: Loader of ``(image, class_index)`` batches.
        optimizer: Optimizer stepped after every task batch.
        epoch: Epoch number, used only for the progress-bar label.
        alpha: Kept for interface compatibility. The two losses are optimised
            in separate steps and ``alpha`` is currently NOT applied.

    Returns:
        dict with ``'expr_loss'`` and ``'gaze_loss'`` (mean loss per batch) and
        ``'expr_accuracy'`` (percent). A value is ``nan`` if its loader yielded
        no batches.

    NOTE(review): ``model.train()`` also puts the backbone's BatchNorm layers in
    training mode, so their running statistics update even while the backbone
    parameters are frozen. Confirm this is intended.
    NOTE(review): the sample output in this notebook shows gaze targets mixing
    values in the hundreds with values below 1, and the MSE is computed on the
    raw targets. Consider normalising them.
    """
    model.train()

    # Keep the two tasks' losses apart: CE and MSE are on different scales.
    expr_loss_sum, expr_batches = 0.0, 0
    gaze_loss_sum, gaze_batches = 0.0, 0
    correct_expr = 0
    total_expr = 0

    gaze_iter = iter(train_loader_gaze)
    expr_iter = iter(train_loader_expr)

    # Use the maximum length of the loaders to ensure each gets fully iterated over time
    max_len = max(len(train_loader_gaze), len(train_loader_expr))

    with tqdm(range(max_len), unit="batch") as tepoch:
        for _ in tepoch:
            tepoch.set_description(f"Epoch {epoch}")

            # Expression step. next(..., None) confines exhaustion handling to
            # the fetch, so a StopIteration raised inside the optimisation
            # code can no longer be swallowed.
            expr_batch = next(expr_iter, None)
            if expr_batch is not None:
                data_expr, target_expr = expr_batch
                data_expr, target_expr = data_expr.to(device), target_expr.to(device)
                optimizer.zero_grad()
                emotion_output, _ = model(data_expr)
                loss_expr = F.cross_entropy(emotion_output, target_expr)
                loss_expr.backward()
                optimizer.step()

                expr_loss_sum += loss_expr.item()
                expr_batches += 1
                predicted = emotion_output.argmax(dim=1)
                total_expr += target_expr.size(0)
                correct_expr += predicted.eq(target_expr).sum().item()

            # Gaze step, skipped once the gaze loader is exhausted.
            gaze_batch = next(gaze_iter, None)
            if gaze_batch is not None:
                data_gaze, target_gaze = gaze_batch
                data_gaze, target_gaze = data_gaze.to(device), target_gaze.to(device)
                optimizer.zero_grad()
                _, gaze_output = model(data_gaze)
                loss_gaze = F.mse_loss(gaze_output, target_gaze)
                loss_gaze.backward()
                optimizer.step()

                gaze_loss_sum += loss_gaze.item()
                gaze_batches += 1

    nan = float('nan')
    return {
        'expr_loss': expr_loss_sum / expr_batches if expr_batches else nan,
        'expr_accuracy': 100.0 * correct_expr / total_expr if total_expr else nan,
        'gaze_loss': gaze_loss_sum / gaze_batches if gaze_batches else nan,
    }



# Baseline: emotion accuracy before any training of the heads.
test_expr(model=model, device=device, test_loader=test_loader_expr)

# Head-only phase: four epochs (1..4), evaluating on the expression test set after each.
for epoch in range(1, 5):
    train(
        model=model,
        device=device,
        train_loader_gaze=train_loader_gaze,
        train_loader_expr=train_loader_expr,
        optimizer=optimizer,
        epoch=epoch,
    )
    test_expr(model=model, device=device, test_loader=test_loader_expr)




Test set: Average loss: 2.0854, Accuracy: 494/3999 (12%)


Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:19<00:00, 12.90batch/s]



Test set: Average loss: 2.6698, Accuracy: 867/3999 (22%)


Epoch 2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:14<00:00, 16.76batch/s]



Test set: Average loss: 2.0574, Accuracy: 1159/3999 (29%)


Epoch 3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:15<00:00, 16.64batch/s]



Test set: Average loss: 2.4261, Accuracy: 929/3999 (23%)


Epoch 4: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:15<00:00, 16.56batch/s]



Test set: Average loss: 2.4721, Accuracy: 1111/3999 (28%)


In [117]:
# Re-run the expression evaluation on the current model state.
test_expr(model, device, test_loader_expr)


Test set: Average loss: 2.4721, Accuracy: 1111/3999 (28%)


In [118]:
# Fine-tuning phase: make every parameter trainable again, backbone included.
for unfrozen_param in model.parameters():
    unfrozen_param.requires_grad = True

# All parameters now require gradients, so the optimizer takes the full set.
# It uses a smaller learning rate than the first optimizer.
optimizer = optim.Adam(model.parameters(), lr=1e-4)
for epoch in range(1, 11):
    train(
        model=model,
        device=device,
        train_loader_gaze=train_loader_gaze,
        train_loader_expr=train_loader_expr,
        optimizer=optimizer,
        epoch=epoch,
    )
    test_expr(model=model, device=device, test_loader=test_loader_expr)

Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:22<00:00, 11.28batch/s]



Test set: Average loss: 2.7144, Accuracy: 1170/3999 (29%)


Epoch 2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:22<00:00, 11.28batch/s]



Test set: Average loss: 2.5518, Accuracy: 1225/3999 (31%)


Epoch 3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:22<00:00, 11.11batch/s]



Test set: Average loss: 2.8598, Accuracy: 1160/3999 (29%)


Epoch 4: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:22<00:00, 11.17batch/s]



Test set: Average loss: 2.5513, Accuracy: 1239/3999 (31%)


Epoch 5: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:22<00:00, 11.09batch/s]



Test set: Average loss: 2.6053, Accuracy: 1271/3999 (32%)


Epoch 6: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:22<00:00, 11.01batch/s]



Test set: Average loss: 2.6709, Accuracy: 1261/3999 (32%)


Epoch 7: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:24<00:00, 10.37batch/s]



Test set: Average loss: 2.5357, Accuracy: 1296/3999 (32%)


Epoch 8: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:23<00:00, 10.75batch/s]



Test set: Average loss: 2.6011, Accuracy: 1282/3999 (32%)


Epoch 9: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:18<00:00, 13.85batch/s]



Test set: Average loss: 2.4956, Accuracy: 1320/3999 (33%)


Epoch 10: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:17<00:00, 13.93batch/s]



Test set: Average loss: 2.5204, Accuracy: 1304/3999 (33%)
