In [None]:
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
import pandas as pd
from transformers import BertModel
# Report the library versions this notebook was run with.
print("PyTorch Version: ", torch.__version__)
print("Torchvision Version: ", torchvision.__version__)

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(device)



In [None]:
# Training hyperparameters: number of passes over the data and mini-batch size.
num_epochs, batch_size = 1, 8


In [None]:
import torch
from PIL import Image

class MyDataset(torch.utils.data.Dataset):
    """Image/label dataset backed by a JSON-lines index file.

    The index ``{root_dir}/{mode}.jsonl`` is read with pandas; its ``label``
    column supplies the targets and its ``img`` column the image paths,
    which are resolved relative to ``root_dir``.

    NOTE(review): items are ``(image, label)`` pairs only. ``train_late_fusion``
    in this notebook unpacks four values per batch (image, input_ids,
    attention mask, label) -- confirm how the text inputs are meant to be
    produced.
    """

    def __init__(self, transforms = None, root_dir = 'data', mode = 'train'):
        """Load the index for one split.

        Args:
            transforms: Optional callable applied to each PIL image.
            root_dir: Directory holding the ``.jsonl`` index and the images.
            mode: Split name; selects ``{root_dir}/{mode}.jsonl``.
        """
        index = pd.read_json(f"{root_dir}/{mode}.jsonl", lines=True)
        self.df = index
        self.labels = index.label
        self.image_names = index.img
        self.transforms = transforms
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of examples (one per label row)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the example at position ``idx``."""
        # Samplers may hand over a tensor index; pandas needs a plain Python value.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        relative_path = self.image_names.iloc[idx]
        # Force three channels so greyscale / RGBA files share one layout.
        img = Image.open(f"{self.root_dir}/{relative_path}").convert('RGB')

        if self.transforms:
            img = self.transforms(img)

        label = self.labels.iloc[idx]
        return img, label



In [None]:
# Training applies random crop/flip augmentation followed by normalization;
# validation only resizes, center-crops and normalizes.
def alpha_buster(x):
    """Keep only the first three channels of a (C, H, W) tensor.

    Defined as a named function rather than a lambda so the transform
    pipeline can be pickled when DataLoader workers are spawned.
    """
    return x[:3, :, :]


input_size = 224

# Per-channel mean / std conventionally used with ImageNet-pretrained torchvision models.
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
        alpha_buster
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        normalize,
        alpha_buster
    ]),
}

print("Initializing Datasets and Dataloaders...")

# Create training and validation datasets
image_datasets = {
    split: MyDataset(transforms = data_transforms[split], mode = split)
    for split in ['train', 'val']
}

# Create training and validation dataloaders.
# Only the training split is shuffled: validation metrics do not depend on
# the order of the examples, so a fixed order keeps evaluation repeatable.
dataloaders_dict = {
    split: torch.utils.data.DataLoader(
        image_datasets[split],
        batch_size=batch_size,
        shuffle=(split == 'train'),
        num_workers=4,
    )
    for split in ['train', 'val']
}

# Detect if we have a GPU available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:

# Text encoder: pretrained BERT loaded from the 'bert-base-uncased' checkpoint.
bert_checkpoint = 'bert-base-uncased'
bert_model = BertModel.from_pretrained(bert_checkpoint)

# Image encoder: pretrained ResNet-18 fetched through torch.hub from the
# torchvision repository pinned at tag v0.9.0.
resnet18_model = torch.hub.load(
    'pytorch/vision:v0.9.0',
    'resnet18',
    pretrained=True,
)


In [None]:
# Fusion model: concatenates the BERT and ResNet feature vectors and feeds
# them to a single linear layer for binary classification.
class FusionModel(torch.nn.Module):
    """Binary classifier over concatenated text (BERT) and image (ResNet) features.

    Args:
        bert_model: Text encoder called as
            ``bert_model(input_ids=..., attention_mask=...)``; element ``[1]``
            of its output is used as the text feature vector.
        resnet18_model: Image encoder called as ``resnet18_model(image)``.
            If it still carries a ``torch.nn.Linear`` classifier head in its
            ``fc`` attribute, that head is replaced IN PLACE by
            ``torch.nn.Identity`` so the encoder emits its pooled features
            (512-d for ResNet-18) instead of class logits.
    """

    def __init__(self, bert_model, resnet18_model):
        super(FusionModel, self).__init__()
        self.bert_model = bert_model
        self.resnet18_model = resnet18_model

        # A stock torchvision ResNet ends in ``fc: Linear(512 -> 1000)``.
        # Without removing it, the concatenated vector would have 768 + 1000
        # features and would not match the fusion layer below.
        image_dim = 512
        head = getattr(resnet18_model, 'fc', None)
        if isinstance(head, torch.nn.Linear):
            image_dim = head.in_features
            resnet18_model.fc = torch.nn.Identity()

        # Read BERT's hidden size from its config when one is available;
        # 768 is the value the original code hard-coded.
        text_dim = getattr(getattr(bert_model, 'config', None), 'hidden_size', 768)

        # One output unit: binary classification.
        self.fc = torch.nn.Linear(text_dim + image_dim, 1)

    def forward(self, input_ids, attention_mask, image):
        """Return one probability per example, shape ``(batch, 1)``.

        ``input_ids`` and ``attention_mask`` are forwarded to the text
        encoder, ``image`` to the image encoder; both feature tensors must
        share the same batch dimension so they can be concatenated on dim 1.
        """
        bert_output = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)[1]
        resnet_output = self.resnet18_model(image)
        fusion_output = torch.cat((bert_output, resnet_output), dim=1)
        output = self.fc(fusion_output)
        # Sigmoid turns the logit into a probability.
        output = torch.sigmoid(output)
        return output




# Instantiate the fusion model and move it to the selected device, so its
# weights live on the same device the training loop moves each batch to.
fusion_model = FusionModel(bert_model, resnet18_model).to(device)
print(fusion_model)

In [None]:
def train_late_fusion(model, dataloaders, resnet, bert, criterion, optimizer, num_epochs):
    """Train and validate a fusion model, restoring the best-validation weights.

    Args:
        model: Module called as ``model(input_ids, attention_masks, images)``
            that returns one probability per example, shape ``(batch, 1)``.
        dataloaders: Mapping with ``'train'`` and ``'val'`` loaders. Every
            batch must unpack to ``(images, input_ids, attention_masks, labels)``.
        resnet: Unused. Kept for signature compatibility; the image encoder
            is run inside ``model``.
        bert: Unused. Kept for signature compatibility; the text encoder is
            run inside ``model``.
        criterion: Loss called as ``criterion(outputs, targets)`` with float
            targets of shape ``(batch, 1)``.
        optimizer: Optimizer over the parameters of ``model``.
        num_epochs: Number of full train + validation passes.

    Returns:
        ``(model, val_acc_history)``: the model loaded with the weights that
        reached the best validation accuracy, and the list of per-epoch
        validation accuracies (floats).
    """
    since = time.time()
    val_acc_history = []

    # Run on whatever device the model already lives on, so inputs and
    # weights can never end up on different devices.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            # Switch dropout / batch-norm layers between training and inference behaviour.
            model.train(is_train)

            # Per-phase accumulators.
            running_loss = 0.0
            running_corrects = 0

            for inputs, input_ids, attention_masks, labels in dataloaders[phase]:
                inputs = inputs.to(run_device)
                input_ids = input_ids.to(run_device)
                attention_masks = attention_masks.to(run_device)
                # (batch,) -> (batch, 1) float targets, matching the model output.
                targets = labels.to(run_device).float().unsqueeze(1)

                optimizer.zero_grad()

                # Track gradients only while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(input_ids, attention_masks, inputs)
                    loss = criterion(outputs, targets)

                    if is_train:
                        # Back-propagate and update the weights.
                        loss.backward()
                        optimizer.step()

                preds = (outputs > 0.5).float()

                # statistics -- compare against `targets` (same shape and
                # device as `preds`) so the comparison stays element-wise.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == targets).sum().item()

            # Guard against an empty dataset.
            num_examples = max(len(dataloaders[phase].dataset), 1)
            epoch_loss = running_loss / num_examples
            epoch_acc = running_corrects / num_examples

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'val':
                val_acc_history.append(epoch_acc)
                # Snapshot the weights whenever validation accuracy improves.
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)

    return model, val_acc_history

In [None]:

# Loss and optimizer: binary cross-entropy (BCELoss expects probabilities)
# and Adam over every parameter of the fusion model.
criterion = nn.BCELoss()
optimizer = optim.Adam(fusion_model.parameters(), lr=0.001)


In [None]:
# Run training; rebinding `fusion_model` keeps whatever model the trainer
# returns, and `hist` receives the second element of the returned pair.
fusion_model, hist = train_late_fusion(
    model=fusion_model,
    dataloaders=dataloaders_dict,
    resnet=resnet18_model,
    bert=bert_model,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
)