In [24]:


import torch
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
import pandas as pd

In [25]:
# Prepare the input data
df = pd.read_csv('data/paintings_with_filenames.csv') 
df['full_description'] = df['title'] + ' ' + df['depicts']+ ' ' + df['wga_description']+ ' ' + df['article_text']
# Filter out NaN values from 'full_description'
df = df.dropna(subset=['full_description'])
df.head()


Unnamed: 0,item,title,author_wikidata,author_name,creation_date,origin_country,display_country,display_location,type,school,...,wiki_url,image_url,depicts,wikipedia_url,article_text,title_clean,wga_url,wga_description,filename,full_description
3,http://www.wikidata.org/entity/Q734082,Regatta at Sainte-Adresse,http://www.wikidata.org/entity/Q296,Claude Monet,1867-01-01T00:00:00Z,,United States of America,Metropolitan Museum of Art,marine art,Impressionism,...,,https://commons.wikimedia.org/wiki/Special:Fil...,"parasol, sailboat, Sainte-Adresse, church, mar...",https://en.wikipedia.org/wiki/Regatta_at_Saint...,The Regatta at Sainte-Adresse is an oil-on-can...,regatta at sainteadresse,https://www.wga.hu/html/m/monet/01/early16.html,"In the summer of 1867, Monet painted a number ...","Claude_Monet,_1867,_Regatta_at_Sainte-Adresse,...","Regatta at Sainte-Adresse parasol, sailboat, S..."
4,http://www.wikidata.org/entity/Q472037,By the Seashore,http://www.wikidata.org/entity/Q39931,Pierre-Auguste Renoir,1883-01-01T00:00:00Z,,United States of America,Metropolitan Museum of Art,portrait,Impressionism,...,,https://commons.wikimedia.org/wiki/Special:Fil...,"portrait, Saint Peter Port, coast, chair, woman",https://en.wikipedia.org/wiki/By_the_Seashore,By the Seashore is a painting by Pierre-August...,by the seashore,https://www.wga.hu/html/r/renoir/3/3renoi20.html,This canvas was painted in the artist's studio...,Pierre-Auguste_Renoir_-_Femme_assise_au_bord_d...,"By the Seashore portrait, Saint Peter Port, co..."
8,http://www.wikidata.org/entity/Q877191,The Three Philosophers,http://www.wikidata.org/entity/Q8459,Giorgione,1500-01-01T00:00:00Z,,Austria,Kunsthistorisches Museum,landscape painting,,...,,https://commons.wikimedia.org/wiki/Special:Fil...,"philosopher, landscape",https://en.wikipedia.org/wiki/The_Three_Philos...,The Three Philosophers is an oil painting on c...,the three philosophers,https://www.wga.hu/html/g/giorgion/various/thr...,The Three Philosophers must be a work of the l...,Giorgione_-_Three_Philosophers_-_Google_Art_Pr...,"The Three Philosophers philosopher, landscape ..."
9,http://www.wikidata.org/entity/Q878981,The Mocking of Christ,http://www.wikidata.org/entity/Q154338,Matthias Grünewald,1504-01-01T00:00:00Z,,Germany,Alte Pinakothek,religious art,,...,,https://commons.wikimedia.org/wiki/Special:Fil...,"Mocking of Jesus, Jesus",https://en.wikipedia.org/wiki/The_Mocking_of_C...,The Mocking of Christ (German: Die Verspottung...,the mocking of christ,https://www.wga.hu/html/g/grunewal/1/04mock.html,Grünewald's earliest datable work is the Mocki...,Mathis_Gothart_Grünewald_062.jpg,"The Mocking of Christ Mocking of Jesus, Jesus ..."
28,http://www.wikidata.org/entity/Q212616,The Raft of the Medusa,http://www.wikidata.org/entity/Q184212,Théodore Géricault,1819-01-01T00:00:00Z,,France,Room 700,marine art,Romanticism,...,,https://commons.wikimedia.org/wiki/Special:Fil...,"sitting, lying, standing, Méduse, agony, raft,...",https://en.wikipedia.org/wiki/The_Raft_of_the_...,The Raft of the Medusa (French: Le Radeau de l...,the raft of the medusa,https://www.wga.hu/html/g/gericaul/1/105geric....,In expressing the predicament of the shipwreck...,JEAN_LOUIS_THÉODORE_GÉRICAULT_-_La_Balsa_de_la...,"The Raft of the Medusa sitting, lying, standin..."


In [26]:


# Define the classes
classes = ['wine', 'beverage', 'meat', 'fruit', 'vegetable', 'bread', 'dairy', 'dessert', 'seafood']

# Load TinyBERT model and tokenizer with multi-label configuration
model_name = 'huawei-noah/TinyBERT_General_4L_312D'
tokenizer = BertTokenizer.from_pretrained(model_name)

config = BertConfig.from_pretrained(
    model_name,
    num_labels=len(classes),
    problem_type="multi_label_classification"
)

model = BertForSequenceClassification.from_pretrained(model_name, config=config)

# Set device to MPS if available
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
model.to(device)

# Batch tokenize the text
encoded_dict = tokenizer.batch_encode_plus(
    df['full_description'].tolist(),
    add_special_tokens=True,
    max_length=266,
    padding='max_length',
    truncation=True,
    return_attention_mask=True,
    return_tensors='pt'
)

input_ids = encoded_dict['input_ids']
attention_masks = encoded_dict['attention_mask']

# Create DataLoader for batch processing
dataset = TensorDataset(input_ids, attention_masks)
batch_size = 64  # Adjust based on your system's memory
dataloader = DataLoader(
    dataset, sampler=SequentialSampler(dataset), batch_size=batch_size
)

# Classify the descriptions
model.eval()
predictions = []

with torch.no_grad():
    for batch in dataloader:
        b_input_ids = batch[0].to(device)
        b_attention_masks = batch[1].to(device)
        outputs = model(b_input_ids, attention_mask=b_attention_masks)
        logits = outputs.logits
        predictions.append(logits.cpu())

# Concatenate all predictions
predictions = torch.cat(predictions, dim=0)

# Apply sigmoid to get probabilities
probabilities = torch.sigmoid(predictions).numpy()

# Adjust the threshold as needed
threshold = 0.5  # Example higher threshold
df['predicted_classes'] = [
    [classes[i] for i in range(len(classes)) if pred[i] > threshold]
    for pred in probabilities
]

# Display the results
display(df[['full_description', 'predicted_classes']])

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at huawei-noah/TinyBERT_General_4L_312D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Unnamed: 0,full_description,predicted_classes
3,"Regatta at Sainte-Adresse parasol, sailboat, S...","[wine, beverage, meat, fruit, dairy, dessert, ..."
4,"By the Seashore portrait, Saint Peter Port, co...","[beverage, meat, fruit, dairy, dessert, seafood]"
8,"The Three Philosophers philosopher, landscape ...","[wine, beverage, meat, fruit, dairy, dessert, ..."
9,"The Mocking of Christ Mocking of Jesus, Jesus ...","[wine, beverage, meat, fruit, dairy, dessert, ..."
28,"The Raft of the Medusa sitting, lying, standin...","[wine, beverage, meat, fruit, dairy, dessert, ..."
...,...,...
119983,"The Finding of Moses Moses, woman This small p...","[beverage, meat, fruit, dairy, dessert, seafood]"
119984,"The Finding of Moses Moses, woman This small p...","[beverage, meat, fruit, dairy, dessert, seafood]"
120524,"Death and the Maiden young adult woman, Death ...","[wine, beverage, meat, fruit, dairy, dessert, ..."
120525,"Death and the Maiden young adult woman, Death ...","[wine, beverage, meat, fruit, dairy, dessert, ..."


In [27]:

certain_length = 5  # Example length
filtered_df = df[df['predicted_classes'].apply(lambda x: len(x) > certain_length)]

# Display the filtered results
display(filtered_df[['full_description', 'predicted_classes']])

Unnamed: 0,full_description,predicted_classes
3,"Regatta at Sainte-Adresse parasol, sailboat, S...","[wine, beverage, meat, fruit, dairy, dessert, ..."
4,"By the Seashore portrait, Saint Peter Port, co...","[beverage, meat, fruit, dairy, dessert, seafood]"
8,"The Three Philosophers philosopher, landscape ...","[wine, beverage, meat, fruit, dairy, dessert, ..."
9,"The Mocking of Christ Mocking of Jesus, Jesus ...","[wine, beverage, meat, fruit, dairy, dessert, ..."
28,"The Raft of the Medusa sitting, lying, standin...","[wine, beverage, meat, fruit, dairy, dessert, ..."
...,...,...
119983,"The Finding of Moses Moses, woman This small p...","[beverage, meat, fruit, dairy, dessert, seafood]"
119984,"The Finding of Moses Moses, woman This small p...","[beverage, meat, fruit, dairy, dessert, seafood]"
120524,"Death and the Maiden young adult woman, Death ...","[wine, beverage, meat, fruit, dairy, dessert, ..."
120525,"Death and the Maiden young adult woman, Death ...","[wine, beverage, meat, fruit, dairy, dessert, ..."


In [28]:
df.to_csv('data/predicted_classes.csv', index=False)