In [7]:
import pandas as pd
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
import numpy as np


In [None]:
# Prepare the input data
df = pd.read_csv('data/paintings_with_filenames.csv')

# Columns whose text is merged into one description per painting.
TEXT_COLUMNS = ['title', 'depicts', 'wga_description', 'article_text']

# Plain `+` on string columns yields NaN as soon as ANY operand is NaN, which
# would discard every painting that lacks just one of the fields. Treat a
# missing piece as empty text instead so the remaining pieces are still used.
text_parts = df[TEXT_COLUMNS].fillna('').astype(str)
df['full_description'] = (
    text_parts[TEXT_COLUMNS[0]]
    .str.cat(text_parts[TEXT_COLUMNS[1:]], sep=' ')
    .str.strip()
)

# Keep only paintings that have at least some text to classify; `.copy()`
# makes the result an independent frame so later column assignments are safe.
df = df[df['full_description'] != ''].copy()
df.head()


In [None]:

# Food/drink categories a painting can be tagged with (one output logit each).
classes = [
    'wine',
    'beverage',
    'meat',
    'fruit',
    'vegetable',
    'bread',
    'dairy',
    'dessert',
    'seafood',
]

# Tokenizer and classifier are both loaded from the same pretrained checkpoint.
# NOTE(review): this appears to be the plain base checkpoint, so the
# classification head sized by `num_labels` is presumably freshly initialised
# rather than trained for these classes - confirm whether a fine-tuned
# checkpoint should be loaded here instead.
CHECKPOINT = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(CHECKPOINT)
model = DistilBertForSequenceClassification.from_pretrained(
    CHECKPOINT,
    num_labels=len(classes),
)


# Tokenize the text
MAX_LENGTH = 512  # every description is padded / truncated to this many tokens

texts = df['full_description'].tolist()

if texts:
    # One batched call replaces the per-row encode_plus loop and returns
    # tensors that are already stacked along the batch dimension.
    # `padding='max_length'` is the current spelling of the deprecated
    # `pad_to_max_length=True`; truncation is requested explicitly so that
    # descriptions longer than MAX_LENGTH tokens are cut instead of overflowing.
    encoded = tokenizer(
        texts,
        add_special_tokens=True,      # Add '[CLS]' and '[SEP]'
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',          # Return pytorch tensors.
    )
    # Token ids are stored under 'input_ids'; 'input_text' is not a key the
    # tokenizer returns and raised a KeyError.
    input_ids = encoded['input_ids']
    attention_masks = encoded['attention_mask']
else:
    # No rows to encode: keep the expected 2-D shape so downstream code that
    # slices or inspects these tensors still works.
    input_ids = torch.empty((0, MAX_LENGTH), dtype=torch.long)
    attention_masks = torch.empty((0, MAX_LENGTH), dtype=torch.long)

# Classify the paintings
BATCH_SIZE = 16   # rows per forward pass; lower this if memory is tight
THRESHOLD = 0.5   # minimum sigmoid score for a class to count as predicted

model.eval()
logit_chunks = []
with torch.no_grad():
    # Feed the model in fixed-size slices: pushing every row through a single
    # forward pass makes peak memory grow with the size of the whole dataset.
    for start in range(0, input_ids.shape[0], BATCH_SIZE):
        stop = start + BATCH_SIZE
        batch_output = model(
            input_ids[start:stop],
            attention_mask=attention_masks[start:stop],
        )
        # The first element of the model output holds the logits.
        logit_chunks.append(batch_output[0])

if logit_chunks:
    logits = torch.cat(logit_chunks, dim=0)
else:
    # Nothing to classify: keep a (0, n_classes) shape so the steps below
    # produce an empty result instead of failing on torch.cat([]).
    logits = torch.empty((0, len(classes)))

# Independent per-class scores (multi-label): sigmoid rather than softmax.
predictions = torch.sigmoid(logits).detach().cpu().numpy()

# Apply a threshold to get the predicted classes
df['predicted_classes'] = [
    [label for label, score in zip(classes, scores) if score > THRESHOLD]
    for scores in predictions
]

# Display the results
display(df)