In [None]:
!pip install transformers

In [None]:
from nltk.tokenize import sent_tokenize, word_tokenize
from transformers import DistilBertTokenizer, DistilBertModel, BertTokenizer, BertModel
from sklearn.metrics import classification_report
import torchvision.transforms.functional as F
import torchvision.transforms as transforms
from sklearn.utils import class_weight
from torchviz import make_dot
import torch.nn.functional as Func
from nltk.stem import PorterStemmer
import torchvision.models as models
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import pandas as pd
import numpy as np
import torchvision
import itertools
import shutil
import string
import pickle
import torch
import nltk
import json
import re
import os

In [None]:
# Use the first CUDA GPU when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

In [None]:
class DepressionClassifier(nn.Module):
    """Image + text classifier built from ResNet-50 and DistilBERT encoders.

    The image branch is a ResNet-50 whose final layer is replaced by a 512-d
    projection. The text branch is DistilBERT, from which the hidden state at
    sequence position 0 (768-d) is taken. Both feature vectors are
    concatenated and fed to a small MLP head that emits `num_classes` logits.

    The class also bundles its own training, validation, test and plotting
    utilities. All tensors are moved to the notebook-level `device` global.
    """

    def __init__(self, num_epochs, train_loader, val_loader, num_classes=2, dropout=0.5):
        """Build both encoders, the classification head and the bookkeeping lists.

        Args:
            num_epochs: Number of epochs `train_model` runs per call.
            train_loader: Iterable yielding `(images, texts, labels)` batches.
            val_loader: Iterable yielding `(images, texts, labels)` batches.
            num_classes: Number of output logits.
            dropout: Dropout probability used in the classifier head.
        """
        super().__init__()
        self.num_classes = num_classes
        self.dropout = dropout
        self.num_epochs = num_epochs

        # Image model
        # NOTE(review): recent torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; kept as-is for compatibility with older versions.
        self.image_model = models.resnet50(pretrained=True)
        self.image_model.fc = nn.Linear(self.image_model.fc.in_features, 512)

        # Text model
        self.text_model = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(512 + 768, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, self.num_classes)
        )
        self.criterion = nn.CrossEntropyLoss()

        # data loaders
        self.train_loader = train_loader
        self.val_loader = val_loader

        # loss and accuracy lists (one entry appended per trained epoch)
        self.train_loss, self.train_accuracy = [], []
        self.val_loss, self.val_accuracy = [], []

        # Move every sub-module at once. Previously only the ResNet backbone
        # was moved, leaving the new `fc`, DistilBERT and the head on the CPU
        # while `forward` sends its inputs to `device`.
        self.to(device)

    # forward pass
    def forward(self, images, texts):
        """Return class logits for a batch of images and their texts.

        Args:
            images: Image tensor accepted by ResNet-50 (moved to `device` here).
            texts: Raw text(s) handed to the DistilBERT tokenizer.

        Returns:
            Tensor of logits with one row per sample.
        """
        images = images.to(device)
        image_features = self.image_model(images)

        encoded = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
        input_ids = encoded.input_ids.to(device)
        # Pass the mask so padded positions are ignored by attention; without
        # it a sample's features depend on how much padding its batch needed.
        attention_mask = encoded.attention_mask.to(device)

        hidden_states = self.text_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        text_features = hidden_states[:, 0, :]

        features = torch.cat([image_features, text_features], dim=1)
        return self.classifier(features)

    # predict method
    def predict(self, image, text):
        """Predict the class index for a single image/text pair.

        Inference runs in eval mode with gradients disabled, and the mode the
        model was in beforehand is restored afterwards.

        Args:
            image: Tensor holding one image; reshaped to `(1, 3, 224, 224)`.
            text: Text for that image, passed to the tokenizer unchanged.

        Returns:
            Tensor of shape `(1,)` holding the predicted class index.
        """
        image = torch.reshape(image, (1, 3, 224, 224))

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                output = self(image, text)
        finally:
            self.train(was_training)

        # softmax is monotonic, so the argmax of the logits is the same class.
        return output.argmax(dim=1)

    def _run_evaluation(self, loader, collect_labels=False):
        """Run one gradient-free pass over `loader` and summarise it.

        Does not touch train/eval mode; the caller selects it.

        Returns:
            `(avg_loss, accuracy_percent, predicted_labels, actual_labels)`.
            The two lists stay empty unless `collect_labels` is true; when
            filled they hold 0-d CPU tensors, one per sample.
        """
        running_loss, n_correct, n_samples = 0.0, 0, 0
        predicted_labels, actual_labels = [], []

        with torch.no_grad():
            for images, texts, labels in loader:
                images = images.to(device)
                labels = labels.to(device)

                outputs = self(images, texts)
                running_loss += self.criterion(outputs, labels).item()

                _, predicted = torch.max(outputs, 1)
                n_samples += labels.size(0)
                n_correct += (predicted == labels).sum().item()

                if collect_labels:
                    # Move to CPU so downstream code never receives CUDA tensors.
                    predicted_labels.extend(predicted.cpu())
                    actual_labels.extend(labels.cpu())

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        avg_loss = running_loss / max(len(loader), 1)
        accuracy = 100.0 * n_correct / max(n_samples, 1)
        return avg_loss, accuracy, predicted_labels, actual_labels

    def train_model(self, optimizer, log_every=266):
        """Train for `self.num_epochs` epochs, validating after each one.

        Per-epoch metrics are appended to `train_loss`, `train_accuracy`,
        `val_loss` and `val_accuracy`. The model is left in eval mode after
        the final validation pass.

        Args:
            optimizer: Optimizer already bound to this model's parameters.
            log_every: Print progress every this many steps; 0/None disables.

        Returns:
            True once training has finished.
        """
        n_total_steps = len(self.train_loader)
        run_val_accuracy = []  # validation accuracies of this call only

        for epoch in range(self.num_epochs):
            # training section
            self.train()
            running_loss, n_correct, n_samples = 0.0, 0, 0

            for index, (images, texts, labels) in enumerate(self.train_loader):
                # Forward pass
                images = images.to(device)
                labels = labels.to(device)

                outputs = self(images, texts)
                loss = self.criterion(outputs, labels)

                # Backward and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                _, predicted = torch.max(outputs, 1)

                n_samples += labels.size(0)
                n_correct += (predicted == labels).sum().item()

                if log_every and (index + 1) % log_every == 0:
                    print(f'Epoch --> [{epoch+1}/{self.num_epochs}] | Step --> [{index+1}/{n_total_steps}] | Loss --> {loss.item():.4f} | Accuracy --> {n_correct/n_samples:.4f}')

            training_loss = running_loss / max(n_total_steps, 1)
            accuracy = 100.0 * n_correct / max(n_samples, 1)

            self.train_loss.append(round(training_loss, 4))
            self.train_accuracy.append(accuracy)

            # validation section, run in evaluation mode
            self.eval()
            avg_val_loss, val_accuracy, _, _ = self._run_evaluation(self.val_loader)

            self.val_loss.append(avg_val_loss)
            self.val_accuracy.append(val_accuracy)
            run_val_accuracy.append(val_accuracy)

        # Average over the epochs actually run in this call, so repeated calls
        # (or num_epochs == 0) do not skew or break the summary.
        mean_val_accuracy = (
            sum(run_val_accuracy) / len(run_val_accuracy) if run_val_accuracy else 0.0
        )
        print(f"\nValidation Accuracy of the Network: {mean_val_accuracy} %\n")
        print("Training Complete !!")
        return True

    def test_model(self, loader):
        """Evaluate on `loader`, print the accuracy and return the labels.

        The model is switched to eval mode for the pass and then put back in
        whichever mode it was in when the method was called.

        Args:
            loader: Iterable yielding `(images, texts, labels)` batches.

        Returns:
            `(predicted_labels, actual_labels)`: flat lists of 0-d CPU tensors.
        """
        was_training = self.training
        self.eval()
        try:
            _, accuracy, predicted_labels, actual_labels = self._run_evaluation(
                loader, collect_labels=True
            )
        finally:
            self.train(was_training)

        print(f"Test Accuracy of the Network: {accuracy} %\n")
        return predicted_labels, actual_labels

    @staticmethod
    def _plot_history(val_values, train_values, val_label, train_label, ylabel, title):
        """Plot a validation and a training metric against the epoch index."""
        plt.figure()
        # Index by the recorded history length rather than num_epochs, so the
        # x and y sizes always match even after several train_model calls.
        plt.plot(range(len(val_values)), val_values, 'b', label=val_label)
        plt.plot(range(len(train_values)), train_values, 'r', label=train_label)
        plt.xlabel("Number of Epochs")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.grid()
        plt.show()

    # utility method for plotting loss and epochs
    def plot_loss_vs_epoch(self):
        """Plot recorded training and validation loss per epoch."""
        self._plot_history(
            self.val_loss, self.train_loss,
            'Validation Loss', 'Training Loss',
            "Loss", "EPOCH VS LOSS PLOT",
        )

    # utility method for plotting accuracies and epochs
    def plot_accuracy_vs_epoch(self):
        """Plot recorded training and validation accuracy per epoch."""
        self._plot_history(
            self.val_accuracy, self.train_accuracy,
            'Validation Accuracy', 'Training Accuracy',
            "Accuracy", "EPOCH VS ACCURACY PLOT",
        )

In [None]:
# The loaded object is used directly as the model, so the checkpoint is a
# pickled module and DepressionClassifier must be defined before this cell.
# Unpickling can execute arbitrary code: only load checkpoints you trust.
model = torch.load(
    '/content/ir_project_model_1.pth',
    map_location=device,  # remap GPU-saved tensors when running on the CPU fallback
    weights_only=False,   # full-module pickle; PyTorch >= 2.6 defaults to weights-only
).to(device)
model.eval()  # inference only below: disable dropout / batch-norm updates

In [None]:
# Pick the item at index 3 as the demo sample.
# NOTE(review): `images` and `texts` are not defined in the visible cells;
# presumably one batch drawn from a data loader. Confirm before Run All.
sample_image, sample_text = images[3], texts[3]

In [None]:
# Classify the chosen sample, then fetch the label stored at the same index (3).
# NOTE(review): `labels` is not defined in the visible cells. Confirm its source.
prediction = model.predict(image=sample_image, text=sample_text)
actual_label = labels[3]

In [None]:
# Show the predicted class next to the ground-truth label, one per line.
print(f"Prediction: {prediction}", f"Actual: {actual_label}", sep="\n")