In [1]:
import torch
from PIL import Image
import requests
from lavis.models import load_model_and_preprocess

from torch import nn
import pandas as pd
import os

from tqdm import tqdm

from sklearn.metrics import classification_report
import json

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from torch.utils.data import Dataset

# Load the model

In [2]:
# Select the compute device: CUDA when available, otherwise CPU.
# Always build a torch.device (not a bare string on the CPU branch) so the
# variable has one consistent type wherever it is passed downstream.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the BLIP-2 image-text-matching model ("pretrain" weights) with is_eval=True,
# plus its image and text preprocessors (looked up by split name, e.g. ["train"]).
model, vis_processors, text_processors = load_model_and_preprocess(
    name="blip2_image_text_matching",
    model_type="pretrain",
    is_eval=True,
    device=device,
)

# Define the Dataset class

In [None]:
class TwitterCOMMsDataset(Dataset):
    """Twitter-COMMs caption/image pairs read from a CSV and preprocessed for BLIP-2.

    Each CSV row must provide the columns ``full_text`` (caption), ``filename``
    (image file inside ``img_dir``), ``topic`` (``<domain>_<difficulty>...``) and
    ``falsified`` (label).
    """

    def __init__(self, csv_path, img_dir, vis_processor=None, text_processor=None):
        """
        Args:
            csv_path (str): Path to the {train_completed|val_completed}.csv file.
                Its first column is used as the DataFrame index.
            img_dir (str): Directory containing the images named in ``filename``.
            vis_processor (callable, optional): Maps a PIL RGB image to a tensor.
                When None, the notebook-global ``vis_processors["train"]`` is used.
            text_processor (callable, optional): Maps a caption string to its
                processed form. When None, ``text_processors["train"]`` is used.
        """
        self.df = pd.read_csv(csv_path, index_col=0)
        self.img_dir = img_dir
        self.vis_processor = vis_processor
        self.text_processor = text_processor

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, preprocess and return the sample at positional index ``idx``.

        Returns:
            dict with keys ``text``, ``image``, ``topic``, ``label``, ``domain``
            and ``difficulty``.

        Raises:
            OSError: if the image file cannot be opened or decoded.
        """
        item = self.df.iloc[idx]
        caption = item['full_text']
        img_filename = item['filename']
        topic = item['topic']
        label = item['falsified']

        topic_parts = topic.split('_')
        domain = topic_parts[0]
        diff = topic_parts[1]

        # Use the instance's directory (not a notebook global) and close the file
        # handle once the pixels are decoded; re-raise instead of continuing with
        # an unbound image.
        img_path = os.path.join(self.img_dir, img_filename)
        try:
            with Image.open(img_path) as img:
                raw_image = img.convert('RGB')
        except OSError as e:
            raise OSError(f"Could not load image for row {idx}: {img_path}") from e

        # Fall back to the globally loaded LAVIS processors (original behavior).
        # NOTE(review): "train" processors may apply random augmentation; for
        # evaluation consider passing vis_processors["eval"] — confirm.
        vis_processor = self.vis_processor if self.vis_processor is not None else vis_processors["train"]
        text_processor = self.text_processor if self.text_processor is not None else text_processors["train"]

        # unsqueeze(0) adds a leading dimension of size 1; a default DataLoader
        # collate would therefore stack to (B, 1, ...) — presumably intended for
        # per-sample model calls; confirm before batching.
        img_emb = vis_processor(raw_image).unsqueeze(0)
        txt_emb = text_processor(caption)

        return {"text": txt_emb, "image": img_emb, "topic": topic, "label": label, "domain": domain, "difficulty": diff}
        
    