# Multi-Input Fusion Model for 5-class Classification

This notebook combines facial attribute features with the image data itself.

The idea is as follows:

1) Transformer for image processing: A Vision Transformer (ViT) serves as the image branch. A ViT splits the image into fixed-size patches and processes them with self-attention, which lets it relate distant regions of the image to one another.

2) Dense network for structured data: A feed-forward network of dense layers processes the structured features read from the JSON label file. In the code below, this branch receives the 40 face attribute labels; the spoof type label is the training target, and the illumination and environment labels are loaded but not fed to the model.

3) Fusion layer: After the image and the structured data have been processed separately, their representations are concatenated and passed through a fusion layer.

4) Output layer: A final linear layer produces the logits for the 5 classes.
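
To make the patching step concrete, the following sketch counts how many tokens the transformer sees per image. It assumes the ViT-B/16 setup used later in this notebook (16x16 patches, 224x224 inputs, one extra class token); the helper name `vit_token_count` is made up for illustration.

```python
def vit_token_count(image_size: int = 224, patch_size: int = 16, class_token: bool = True) -> int:
    """Number of tokens a ViT processes for one square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    return num_patches + (1 if class_token else 0)


print(vit_token_count())  # 14 * 14 patches + 1 class token = 197
```

Self-attention compares every token with every other token, so the cost grows with the square of this count.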

## Prep

In [None]:
import json
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
# precision_score and f1_score are needed by the training loop below
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    f1_score,
    precision_score,
    recall_score,
)
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.models import vit_b_16  # Pretrained ViT model from torchvision
from tqdm import tqdm


In [None]:
!gdown "https://drive.google.com/uc?id=1P-ypjfLTZsEpBSyMdzoeThlWd-l1a5rN"

Downloading...
From (original): https://drive.google.com/uc?id=1P-ypjfLTZsEpBSyMdzoeThlWd-l1a5rN
From (redirected): https://drive.google.com/uc?id=1P-ypjfLTZsEpBSyMdzoeThlWd-l1a5rN&confirm=t&uuid=8f96dcc2-4d97-4637-8ea8-5ca177c10749
To: /content/preprocessed_data-20240325T171740Z-001.zip
100% 415M/415M [00:02<00:00, 140MB/s]


In [None]:
!ls

preprocessed_data-20240325T171740Z-001.zip  sample_data


In [None]:
!unzip preprocessed_data-20240325T171740Z-001.zip

Streaming output truncated to the last 5000 lines.
  inflating: preprocessed_data/spoof/026420.jpg  
  inflating: preprocessed_data/spoof/082396.jpg  
  inflating: preprocessed_data/spoof/293551.jpg  
  inflating: preprocessed_data/spoof/169445.jpg  
  inflating: preprocessed_data/spoof/380539.jpg  
  inflating: preprocessed_data/spoof/071578.jpg  
  inflating: preprocessed_data/spoof/294558.jpg  
  inflating: preprocessed_data/spoof/220943.jpg  
  inflating: preprocessed_data/spoof/114486.jpg  
  inflating: preprocessed_data/spoof/083088.jpg  
  inflating: preprocessed_data/spoof/489034.jpg  
  inflating: preprocessed_data/spoof/076571.jpg  
  inflating: preprocessed_data/spoof/134030.jpg  
  inflating: preprocessed_data/spoof/126434.jpg  
  inflating: preprocessed_data/spoof/347099.jpg  
  inflating: preprocessed_data/spoof/281633.jpg  
  inflating: preprocessed_data/spoof/358152.jpg  
  inflating: preprocessed_data/spoof/036155.jpg  
  inflating: preprocessed_data/spoo

In [None]:
!ls preprocessed_data

live  spoof


In [None]:
!ls preprocessed_data/spoof

000039.jpg  060775.jpg	122985.jpg  184143.jpg	244940.jpg  307312.jpg	370371.jpg  432273.jpg
000060.jpg  060780.jpg	123015.jpg  184167.jpg	244953.jpg  307327.jpg	370376.jpg  432286.jpg
000085.jpg  060790.jpg	123026.jpg  184182.jpg	244965.jpg  307330.jpg	370408.jpg  432288.jpg
000094.jpg  060800.jpg	123056.jpg  184193.jpg	244966.jpg  307340.jpg	370435.jpg  432293.jpg
000106.jpg  060833.jpg	123073.jpg  184206.jpg	244972.jpg  307346.jpg	370436.jpg  432297.jpg
000143.jpg  060835.jpg	123077.jpg  184215.jpg	244979.jpg  307366.jpg	370456.jpg  432307.jpg
000152.jpg  060837.jpg	123087.jpg  184216.jpg	244984.jpg  307367.jpg	370457.jpg  432310.jpg
000158.jpg  060841.jpg	123093.jpg  184223.jpg	244991.jpg  307399.jpg	370473.jpg  432319.jpg
000160.jpg  060873.jpg	123095.jpg  184224.jpg	245004.jpg  307427.jpg	370502.jpg  432332.jpg
000186.jpg  060920.jpg	123100.jpg  184245.jpg	245020.jpg  307428.jpg	370504.jpg  432335.jpg
000187.jpg  060939.jpg	123104.jpg  184268.jpg	245035.jpg  307429.jpg	370505.jpg 

In [None]:
import torch
from torch.utils.data import Dataset
import json
from PIL import Image
import os

class SpoofDataset(Dataset):
    def __init__(self, data_dir, label_file, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        with open(label_file, 'r') as f:
            self.labels_and_features = json.load(f)

        self.img_paths = []
        self.features = []
        self.spoof_labels = []
        self.additional_labels = []

        for img_name, data in self.labels_and_features.items():
            # Correct the file extension from .png to .jpg
            img_name = img_name.replace('.png', '.jpg')
            # Determine the subfolder based on the spoof type label
            subfolder = 'live' if data[40] == 0 else 'spoof'
            full_path = os.path.join(data_dir, subfolder, img_name)
            if os.path.exists(full_path):
                self.img_paths.append(full_path)
                self.features.append(data[:40])  # Extract the first 40 face attribute labels
                self.spoof_labels.append(data[40])  # Spoof type label is at position 40
                self.additional_labels.append(data[41:43])  # Illumination and environment labels

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        image = Image.open(img_path).convert('RGB')
        features = self.features[idx]
        spoof_label = self.spoof_labels[idx]
        additional_labels = self.additional_labels[idx]

        if self.transform:
            image = self.transform(image)

        # Convert features and labels to PyTorch tensors
        features_tensor = torch.tensor(features, dtype=torch.float)
        spoof_label_tensor = torch.tensor(spoof_label, dtype=torch.long)
        additional_labels_tensor = torch.tensor(additional_labels, dtype=torch.long)

        # Return the tensor version so the default collate function yields a (batch, 2) tensor
        return image, features_tensor, spoof_label_tensor, additional_labels_tensor
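
The indexing in `SpoofDataset` implies that each JSON entry is a flat list of at least 43 values. The sketch below spells out that layout on a made-up record; `split_record` and the example values are hypothetical and only mirror the slices used in the class above.

```python
from typing import List, Sequence, Tuple

NUM_ATTRIBUTES = 40  # face attribute labels, indices 0..39


def split_record(record: Sequence[int]) -> Tuple[List[int], int, List[int]]:
    """Split one flat label list into (attributes, spoof_label, extra_labels)."""
    if len(record) < NUM_ATTRIBUTES + 3:
        raise ValueError(f"expected at least 43 values, got {len(record)}")
    attributes = list(record[:NUM_ATTRIBUTES])
    spoof_label = record[NUM_ATTRIBUTES]  # index 40: spoof type
    # indices 41 and 42: illumination and environment
    extra_labels = list(record[NUM_ATTRIBUTES + 1:NUM_ATTRIBUTES + 3])
    return attributes, spoof_label, extra_labels


# Hypothetical record, not taken from the real label file.
example = [0] * 40 + [2, 1, 3]
attributes, spoof_label, extra_labels = split_record(example)
print(len(attributes), spoof_label, extra_labels)  # 40 2 [1, 3]
```

A spoof label of 0 maps to the `live` subfolder; any other value maps to `spoof`.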


In [None]:
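from torchvision.models import ViT_B_16_Weights

class CombinedModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Load a Vision Transformer pre-trained on ImageNet
        # (weights= replaces the deprecated pretrained=True flag)
        self.vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        # Replace the classification head with a 256-dimensional projection
        self.vit.heads = nn.Linear(self.vit.heads[0].in_features, 256)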

        # Dense Network for structured data
        self.dense = nn.Sequential(
            nn.Linear(40, 128),  # 40 features as input
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 256),
            nn.ReLU()
        )

        # Fusion and Classifier
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),  # Fusion of 256 (ViT) + 256 (Dense)
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 5)  # Output layer for 5 classes
        )

    def forward(self, images, features):
        img_features = self.vit(images)
        dense_features = self.dense(features)
        combined_features = torch.cat((img_features, dense_features), dim=1)
        output = self.classifier(combined_features)
        return output
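
Before training the full model, it can help to check the fusion wiring on random tensors. The sketch below uses the same concatenate-then-classify pattern as `CombinedModel` but swaps the ViT for a tiny stand-in encoder, so it runs without downloading weights. `TinyFusion` and its stand-in encoder are assumptions made for this check, not part of the notebook's model.

```python
import torch
import torch.nn as nn


class TinyFusion(nn.Module):
    """Same wiring as CombinedModel, with a small stand-in for the ViT branch."""

    def __init__(self, num_features: int = 40, embed_dim: int = 256, num_classes: int = 5):
        super().__init__()
        # Stand-in image encoder (assumption): global average pool + linear layer.
        self.image_encoder = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, embed_dim)
        )
        self.dense = nn.Sequential(nn.Linear(num_features, embed_dim), nn.ReLU())
        self.classifier = nn.Linear(2 * embed_dim, num_classes)

    def forward(self, images, features):
        # Concatenate along the feature dimension: (B, 256) + (B, 256) -> (B, 512)
        fused = torch.cat((self.image_encoder(images), self.dense(features)), dim=1)
        return self.classifier(fused)


model_check = TinyFusion().eval()
with torch.no_grad():
    logits = model_check(torch.randn(4, 3, 224, 224), torch.randn(4, 40))
print(logits.shape)  # torch.Size([4, 5])
```

The same two random tensors can be passed to `CombinedModel` to confirm that it also returns one row of 5 logits per sample.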


In [None]:
model = CombinedModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:03<00:00, 107MB/s]
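
`nn.CrossEntropyLoss` takes raw logits and integer class indices and applies log-softmax internally, which is why `CombinedModel` has no softmax layer. The sketch below computes the same per-sample quantity by hand; the function name `cross_entropy` is only illustrative.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Cross-entropy for one sample: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# With 5 equal logits the model is guessing uniformly, so the loss is ln(5).
uniform = cross_entropy([0.0] * 5, target=2)
print(round(uniform, 4))  # 1.6094
```

This gives a reference point for the training logs: a 5-class model that guesses uniformly scores about 1.61.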


In [None]:
from torch.utils.data import random_split
from torch.utils.data import DataLoader

def create_dataloaders(data_dir, label_file, batch_size=32, val_split=0.2, transform=None):
    full_dataset = SpoofDataset(data_dir=data_dir, label_file=label_file, transform=transform)

    total_size = len(full_dataset)
    val_size = int(total_size * val_split)
    train_size = total_size - val_size

    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return {'train': train_loader, 'val': val_loader}
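
`random_split` draws the validation set without looking at the labels, so a rare class can end up under-represented in either split. One alternative is a stratified split that takes the same fraction from every class. The sketch below is a plain-Python version; `stratified_split_indices` and the toy labels are hypothetical.

```python
import random
from collections import defaultdict
from typing import List, Sequence, Tuple


def stratified_split_indices(
    labels: Sequence[int], val_split: float = 0.2, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Return (train_idx, val_idx) with about val_split of every class in val."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, val_idx = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        n_val = int(len(indices) * val_split)
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return sorted(train_idx), sorted(val_idx)


# Hypothetical, heavily imbalanced labels.
labels = [0] * 90 + [1] * 5 + [2] * 5
train_idx, val_idx = stratified_split_indices(labels)
print(len(train_idx), len(val_idx))  # 80 20
```

With the dataset above, `full_dataset.spoof_labels` could serve as `labels`, and each index list could be wrapped in `torch.utils.data.Subset(full_dataset, idx)` before building the loaders.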


In [None]:
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to fit ViT input requirements
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
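
`ToTensor` scales 8-bit pixel values to the range [0, 1], and `Normalize` then applies `(value - mean) / std` per channel with the ImageNet statistics. The sketch below works through that arithmetic for a single pixel; `normalize_pixel` is an illustrative helper, not a torchvision function.

```python
MEAN = (0.485, 0.456, 0.406)  # per-channel ImageNet means (R, G, B)
STD = (0.229, 0.224, 0.225)   # per-channel ImageNet standard deviations


def normalize_pixel(rgb):
    """Apply (value - mean) / std per channel to one pixel already scaled to [0, 1]."""
    return tuple((v - m) / s for v, m, s in zip(rgb, MEAN, STD))


white = normalize_pixel((1.0, 1.0, 1.0))
black = normalize_pixel((0.0, 0.0, 0.0))
print([round(v, 3) for v in white])  # [2.249, 2.429, 2.64]
print([round(v, 3) for v in black])  # [-2.118, -2.036, -1.804]
```

The model therefore sees inputs roughly in the range -2.1 to 2.6 rather than 0 to 1.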


In [None]:
data_dir = 'preprocessed_data'
label_file = 'train_json_complete.json'

# Create dataloaders
dataloaders = create_dataloaders(data_dir, label_file, batch_size=32, val_split=0.2, transform=transform)

# Define device and move the model onto it so it matches the input tensors
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)


In [None]:
best_metrics = {'val_loss': float('inf'), 'val_accuracy': 0, 'val_precision': 0, 'val_recall': 0, 'val_f1': 0}

In [None]:
# Training loop
for epoch in range(10):
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()
            dataloader = dataloaders['train']
        else:
            model.eval()
            dataloader = dataloaders['val']

        running_loss = 0.0
        all_preds = []
        all_labels = []

        # Wrapping the dataloader with tqdm for a progress bar
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1} {phase.upper()}", leave=False)
        for data in progress_bar:
            images, features, spoof_labels, additional_labels = data
            images, features, spoof_labels = images.to(device), features.to(device), spoof_labels.to(device)

            optimizer.zero_grad()

            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(images, features)
                loss = criterion(outputs, spoof_labels)
                _, preds = torch.max(outputs, 1)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item() * images.size(0)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(spoof_labels.cpu().numpy())

            # Update progress bar description with current loss
            progress_bar.set_description(f"Epoch {epoch+1} {phase.upper()} Loss: {loss.item():.4f}")

        epoch_loss = running_loss / len(dataloader.dataset)
        epoch_accuracy = accuracy_score(all_labels, all_preds)
        epoch_precision = precision_score(all_labels, all_preds, zero_division=0, average='macro')
        epoch_recall = recall_score(all_labels, all_preds, zero_division=0, average='macro')
        epoch_f1 = f1_score(all_labels, all_preds, zero_division=0, average='macro')

        print(f'Epoch {epoch+1} {phase.upper()} Loss: {epoch_loss:.4f} Accuracy: {epoch_accuracy:.4f} Precision: {epoch_precision:.4f} Recall: {epoch_recall:.4f} F1: {epoch_f1:.4f}')

        # Update best metrics for validation phase
        if phase == 'val' and (epoch_loss < best_metrics['val_loss'] or epoch_recall > best_metrics['val_recall']):
            best_metrics.update({
                'epoch': epoch + 1,
                'val_loss': epoch_loss,
                'val_accuracy': epoch_accuracy,
                'val_precision': epoch_precision,
                'val_recall': epoch_recall,
                'val_f1': epoch_f1,
            })

print(f"Best Metrics at Epoch {best_metrics['epoch']}:")
for metric, value in best_metrics.items():
    if metric != 'epoch':
        print(f"{metric.capitalize()}: {value:.4f}")




Epoch 1 TRAIN Loss: 0.1020 Accuracy: 0.9770 Precision: 0.7221 Recall: 0.6670 F1: 0.6616
Epoch 1 VAL Loss: 0.0900 Accuracy: 0.9766 Precision: 0.6552 Recall: 0.6667 F1: 0.6609
Epoch 2 TRAIN Loss: 0.0860 Accuracy: 0.9770 Precision: 0.8084 Recall: 0.6738 F1: 0.6751
Epoch 2 VAL Loss: 0.0735 Accuracy: 0.9782 Precision: 0.9460 Recall: 0.6931 F1: 0.7099
Epoch 3 TRAIN Loss: 0.0773 Accuracy: 0.9775 Precision: 0.8425 Recall: 0.6966 F1: 0.7139
Epoch 3 VAL Loss: 0.0802 Accuracy: 0.9766 Precision: 0.6552 Recall: 0.6667 F1: 0.6609
Epoch 4 TRAIN Loss: 0.0792 Accuracy: 0.9773 Precision: 0.8343 Recall: 0.6896 F1: 0.7028
Epoch 4 VAL Loss: 0.0704 Accuracy: 0.9777 Precision: 0.9653 Recall: 0.6839 F1: 0.6938
Epoch 5 TRAIN Loss: 0.0821 Accuracy: 0.9777 Precision: 0.8637 Recall: 0.6885 F1: 0.7014
Epoch 5 VAL Loss: 0.0855 Accuracy: 0.9780 Precision: 0.8944 Recall: 0.6994 F1: 0.7195
Epoch 6 TRAIN Loss: 0.0734 Accuracy: 0.9777 Precision: 0.8539 Recall: 0.6941 F1: 0.7103
Epoch 6 VAL Loss: 0.0696 Accuracy: 0.9781 Precision: 0.9118 Recall: 0.6969 F1: 0.7158
Epoch 7 TRAIN Loss: 0.0715 Accuracy: 0.9778 Precision: 0.8535 Recall: 0.7010 F1: 0.7208
Epoch 7 VAL Loss: 0.0747 Accuracy: 0.9766 Precision: 0.6552 Recall: 0.6667 F1: 0.6609
Epoch 8 TRAIN Loss: 0.0673 Accuracy: 0.9780 Precision: 0.8574 Recall: 0.7077 F1: 0.7306
Epoch 8 VAL Loss: 0.0649 Accuracy: 0.9781 Precision: 0.9893 Recall: 0.6879 F1: 0.7012
Epoch 9 TRAIN Loss: 0.0676 Accuracy: 0.9778 Precision: 0.8492 Recall: 0.7059 F1: 0.7277
Epoch 9 VAL Loss: 0.0653 Accuracy: 0.9788 Precision: 0.9086 Recall: 0.7113 F1: 0.7379
Epoch 10 TRAIN Loss: 0.0653 Accuracy: 0.9777 Precision: 0.8423 Recall: 0.7140 F1: 0.7382
Epoch 10 VAL Loss: 0.0624 Accuracy: 0.9794 Precision: 0.9072 Recall: 0.7257 F1: 0.7580
Best Metrics at Epoch 10:
Val_loss: 0.0624
Val_accuracy: 0.9794
Val_precision: 0.9072
Val_recall: 0.7257
Val_f1: 0.7580
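
The logs above show accuracy near 0.98 next to macro recall near 0.70. A gap like this usually means the classes are imbalanced and at least one small class is often missed, although the logs alone do not confirm which one. The sketch below reproduces the pattern on made-up labels with hand-written metrics; for these labels the result matches scikit-learn's macro average, which also counts classes that appear only in the predictions.

```python
from typing import Sequence


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of samples whose prediction matches the label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_recall(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Unweighted mean of per-class recall over the classes present in y_true."""
    classes = sorted(set(y_true))
    recalls = []
    for c in classes:
        support = sum(t == c for t in y_true)
        hits = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        recalls.append(hits / support)
    return sum(recalls) / len(classes)


# Toy labels (hypothetical): class 0 dominates, class 2 is never predicted.
y_true = [0] * 96 + [1] * 2 + [2] * 2
y_pred = [0] * 96 + [1] * 2 + [0] * 2
print(accuracy(y_true, y_pred))                # 0.98
print(round(macro_recall(y_true, y_pred), 4))  # 0.6667
```

Calling `classification_report(all_labels, all_preds, zero_division=0)` after the validation phase would print the per-class breakdown for the real predictions.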
