In [None]:
import torch
from transformers import pipeline
import pandas as pd
from tqdm import tqdm
import re


In [None]:
# Load the paintings table from CSV. Later cells read the columns
# `article_text`, `title`, `depicts` and `wga_description` from this frame.
df = pd.read_csv('data/paintings_with_filenames.csv') 
df

In [None]:

# Lead section: everything before the first line that starts with "==", or the
# whole text when the article has no headings at all.
_INTRO_RE = re.compile(r'\A(.*?)(?=^==|\Z)', re.DOTALL | re.MULTILINE)
# Section bodies: capture stops at the next line beginning with "==" (this
# includes deeper headings such as "===") or at the end of the text.
_DESCRIPTION_RE = re.compile(r'== Description ==\n(.*?)(?=\n==|\Z)', re.DOTALL)
_INTERPRETATION_RE = re.compile(r'== Interpretations ==\n(.*?)(?=\n==|\Z)', re.DOTALL)


def extract_relevant_sections(text):
    """Return the lead, "Description" and "Interpretations" parts of an article.

    The article is expected to mark sections with ``== Heading ==`` lines.

    Parameters
    ----------
    text : Any
        Article text. Anything that is not a ``str`` (e.g. NaN / None coming
        from a missing cell) is treated as "no article".

    Returns
    -------
    str or None
        ``None`` when ``text`` is not a string. Otherwise the extracted parts
        separated by blank lines, stripped of surrounding whitespace; this may
        be the empty string when nothing was found.
    """
    if not isinstance(text, str):
        return None

    parts = []

    # The lead pattern always matches (possibly with an empty capture), so an
    # article without headings contributes its full text and an article that
    # starts with a heading contributes no lead at all.
    lead = _INTRO_RE.search(text).group(1).strip()
    if lead:
        parts.append(lead)

    description = _DESCRIPTION_RE.search(text)
    if description:
        parts.append("== Description ==\n" + description.group(1).strip())

    interpretation = _INTERPRETATION_RE.search(text)
    if interpretation:
        # NOTE(review): the text matched under "Interpretations" is emitted
        # under a "Subject" heading -- confirm this relabelling is intended.
        parts.append("== Subject ==\n" + interpretation.group(1).strip())

    return "\n\n".join(parts).strip()


In [None]:

# Pull the lead / Description / Interpretations text out of each article.
# Rows whose `article_text` is not a string get None here.
df['wiki_description'] = df['article_text'].apply(extract_relevant_sections)

# Build the single text the classifier will read. A missing value in any of
# the four components propagates as NaN through the concatenation, so only
# rows with title, depicts, wga_description and article text all present end
# up with a description. (To keep partially described rows instead, fill the
# text columns with '' before this step.)
df['full_description'] = (
    df['title'] + ' ' + df['depicts'] + ' ' + df['wga_description'] + ' ' + df['wiki_description']
)

# Drop rows without a complete description. `.copy()` makes the result an
# independent frame so that adding columns to it later does not trigger
# pandas' SettingWithCopyWarning.
df = df.dropna(subset=['full_description']).copy()
df.head()

In [None]:


# Choose the inference device by name. An integer device index is treated by
# `pipeline` as a CUDA ordinal, so passing 0 does not reliably select Apple's
# MPS backend; a device string is unambiguous.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(f"Using device: {device.upper()}")

# Initialize the zero-shot classifier with a suitable model
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",  # Choose lighter models if needed
    device=device,
)

# Candidate labels scored independently for every description
classes = ['wine', 'beverage', 'meat', 'fruit', 'vegetable', 'bread', 'dairy', 'dessert', 'seafood']

# A label is kept only when its score is strictly above this value
threshold = 0.6

# Number of descriptions handed to the classifier per call
batch_size = 16

# Slice the descriptions into consecutive chunks of `batch_size`
descriptions = df['full_description'].tolist()
sample_range = range(0, len(descriptions), batch_size)
batches = [descriptions[i:i + batch_size] for i in sample_range]

predicted_classes = []

for batch in tqdm(batches, desc="Classifying batches"):
    results = classifier(batch, classes, multi_label=True)
    # NOTE(review): some transformers versions appear to return a bare dict
    # rather than a one-element list when given a single sequence; normalise
    # so a final batch of size 1 is handled either way.
    if isinstance(results, dict):
        results = [results]
    for result in results:
        predicted = [label for label, score in zip(result['labels'], result['scores']) if score > threshold]
        predicted_classes.append(predicted)

# One list of accepted labels per row, in the same order as `descriptions`
df['predicted_classes'] = predicted_classes

# Display the results
df[['full_description', 'predicted_classes']]

In [None]:

# Rows must carry strictly more than this many predicted labels to be kept
certain_length = 5  # Example length

# Boolean mask: True where the row's label list is longer than the cut-off
exceeds_label_count = df['predicted_classes'].apply(
    lambda labels: len(labels) > certain_length
)
filtered_df = df.loc[exceeds_label_count]

# Show the description next to its labels for the rows that passed the filter
shown_columns = ['full_description', 'predicted_classes']
display(filtered_df[shown_columns])

In [None]:
# Write every classified row to CSV without the index column. The
# `predicted_classes` cells hold Python lists, which the CSV stores in their
# string form.
df.to_csv('data/predicted_classes.csv', index=False)

In [None]:
# Write only the rows in `filtered_df` (the label-count filtered subset) to
# CSV, again without the index column.
filtered_df.to_csv('data/filtered_predicted_classes.csv', index=False)