In [None]:
import os
import subprocess

import matplotlib.pyplot as plt
import pandas as pd
import torch
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, TensorDataset, RandomSampler, SequentialSampler
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, AdamW

# Load and preprocess data
# NOTE(review): reading a gs:// URL through pandas requires the `gcsfs` package.
csv_path = "gs://vino-verdict/data/cleaned_wine_df.csv"
df = pd.read_csv(csv_path)

# Drop rows with missing descriptions (the description text is the model input)
df = df.dropna(subset=['description'])

# Binning: pd.cut is right-closed, and include_lowest adds 0 to the first bin, so
# [0, 85] -> low, (85, 90] -> medium, (90, 100] -> high
bins = [0, 85, 90, 100]
rating_labels = ['low', 'medium', 'high']
df['rating_category'] = pd.cut(df['points'], bins=bins, labels=rating_labels, include_lowest=True)

# Missing or out-of-range scores get no bucket (NaN). Drop those rows so that every
# remaining row has a class label; NaN labels would break the integer label array later.
df = df.dropna(subset=['rating_category']).reset_index(drop=True)

In [None]:
# Tokenization: encode every description once; the cross-validation cell slices
# these tensors per fold.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
input_data = tokenizer(
    df['description'].astype(str).tolist(),
    padding=True,      # pad to the longest encoded description (capped by max_length)
    truncation=True,   # cut descriptions longer than max_length tokens
    max_length=512,
    return_tensors="pt",
)
input_ids, attention_mask = input_data["input_ids"], input_data["attention_mask"]

# Map labels to integers (class order follows the bins: low < medium < high)
label_map = {'low': 0, 'medium': 1, 'high': 2}
# `rating_category` has a categorical dtype, so `.map` returns a categorical Series,
# and `.values` would yield a pandas Categorical. Cast to a plain int64 NumPy array so
# it can be indexed per fold and converted to a torch tensor. The cast raises if any
# row has no rating bucket (NaN) rather than letting NaN labels through.
labels = df['rating_category'].map(label_map).astype('int64').to_numpy()

assert len(labels) == input_ids.shape[0], "expected one label per tokenized description"

In [None]:
# 10-fold stratified cross-validation. For each fold: fine-tune a fresh DistilBERT
# classifier on the training split (only one epoch per fold, for simplicity), score
# it on the held-out split, and keep the weights of the fold with the best
# validation accuracy.
SEED = 42
N_SPLITS = 10
BATCH_SIZE = 16        # TODO(review): the original referenced an undefined `batch_size`; tune to GPU memory
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
BEST_MODEL_PATH = "./best_10_fold_model.bin"
# TODO(review): the original referenced an undefined `path_to_save`. Set a GCS
# destination here to upload the best checkpoint with gsutil; None skips the upload.
path_to_save = None

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(SEED)  # the classification head is randomly initialised on each fold

# Plain int64 class ids. This works whether `labels` is a NumPy array or a pandas
# Categorical. Using a separate name means batch unpacking can never clobber it.
label_array = pd.Series(labels).astype("int64").to_numpy()
label_tensor = torch.from_numpy(label_array)


def train_one_epoch(model, dataloader, optimizer, device):
    """Run one optimisation pass of `model` over `dataloader`.

    Returns the list of per-batch training losses.
    """
    model.train()
    batch_losses = []
    for batch_ids, batch_mask, batch_labels in dataloader:
        batch_ids = batch_ids.to(device)
        batch_mask = batch_mask.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        outputs = model(batch_ids, attention_mask=batch_mask, labels=batch_labels)
        outputs.loss.backward()
        optimizer.step()
        batch_losses.append(outputs.loss.item())
    return batch_losses


def evaluate(model, dataloader, device):
    """Score `model` on `dataloader` with gradients disabled.

    Returns (per-batch losses, predicted class ids, true class ids).
    """
    model.eval()
    batch_losses, preds, targets = [], [], []
    with torch.no_grad():
        for batch_ids, batch_mask, batch_labels in dataloader:
            outputs = model(
                batch_ids.to(device),
                attention_mask=batch_mask.to(device),
                labels=batch_labels.to(device),
            )
            batch_losses.append(outputs.loss.item())
            preds.extend(outputs.logits.argmax(dim=1).cpu().tolist())
            # batch_labels itself was never moved, so it is still the CPU tensor from the dataset
            targets.extend(batch_labels.tolist())
    return batch_losses, preds, targets


skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)

fold_metrics = []        # one dict of validation metrics per fold
all_train_losses = []    # per-batch training losses, all folds concatenated
all_val_losses = []      # per-batch validation losses, all folds concatenated
best_val_accuracy = 0
best_model_state = None

for fold, (train_index, val_index) in enumerate(skf.split(input_ids, label_array), start=1):
    print(f"Starting fold {fold}...")

    # Slice ids, mask and labels with the same indices so the three stay aligned
    train_idx = torch.as_tensor(train_index)
    val_idx = torch.as_tensor(val_index)
    train_dataset = TensorDataset(input_ids[train_idx], attention_mask[train_idx], label_tensor[train_idx])
    valid_dataset = TensorDataset(input_ids[val_idx], attention_mask[val_idx], label_tensor[val_idx])
    train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=BATCH_SIZE)
    val_dataloader = DataLoader(valid_dataset, sampler=SequentialSampler(valid_dataset), batch_size=BATCH_SIZE)

    # Reload the pretrained weights on every fold so no fold's model has trained on
    # another fold's validation rows
    model = DistilBertForSequenceClassification.from_pretrained(
        'distilbert-base-uncased', num_labels=len(label_map)
    ).to(device)
    # torch.optim.AdamW: transformers.AdamW is deprecated and removed in recent releases
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    train_losses = train_one_epoch(model, train_dataloader, optimizer, device)
    val_losses, val_preds, true_labels = evaluate(model, val_dataloader, device)

    accuracy = accuracy_score(true_labels, val_preds)
    # zero_division=0: ill-defined per-class scores count as 0 (sklearn's default value),
    # but without emitting a warning on every fold
    f1 = f1_score(true_labels, val_preds, average='weighted', zero_division=0)
    precision = precision_score(true_labels, val_preds, average='weighted', zero_division=0)
    recall = recall_score(true_labels, val_preds, average='weighted', zero_division=0)
    fold_metrics.append(
        {"fold": fold, "accuracy": accuracy, "f1": f1, "precision": precision, "recall": recall}
    )

    print(f"Fold {fold} - Accuracy: {accuracy:.4f}, F1: {f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}")

    all_train_losses.extend(train_losses)
    all_val_losses.extend(val_losses)

    # Check if this fold's accuracy is the best we've seen
    if best_model_state is None or accuracy > best_val_accuracy:
        best_val_accuracy = accuracy
        # Detached CPU copy: the snapshot must not alias the live model's (GPU) tensors
        best_model_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        torch.save(best_model_state, BEST_MODEL_PATH)
        if path_to_save:
            # Pass an argument list (no shell) so the destination is never shell-interpreted.
            # Upload is best-effort: a failure is reported but does not abort the CV run.
            try:
                upload_ok = subprocess.run(["gsutil", "cp", BEST_MODEL_PATH, path_to_save]).returncode == 0
            except OSError:  # e.g. gsutil is not installed
                upload_ok = False
            if not upload_ok:
                print(f"WARNING: uploading {BEST_MODEL_PATH} to {path_to_save} failed")

fold_metrics_df = pd.DataFrame(fold_metrics).set_index("fold")
avg_accuracy, avg_f1, avg_precision, avg_recall = (
    fold_metrics_df[["accuracy", "f1", "precision", "recall"]].mean()
)
n_folds = len(fold_metrics_df)

print(f"Average accuracy over {n_folds} folds: {avg_accuracy:.4f}")
print(f"Average F1 over {n_folds} folds: {avg_f1:.4f}")
print(f"Average Precision over {n_folds} folds: {avg_precision:.4f}")
print(f"Average Recall over {n_folds} folds: {avg_recall:.4f}")

In [None]:
# Plot learning curve: per-batch losses from every cross-validation fold, concatenated.
# Each fold has a 9/10 training split and a 1/10 validation split, both iterated with the
# same batch size, so the two series have very different lengths. Each series gets its
# own x-axis rather than sharing one misleading "Iterations" index; sharey keeps the
# loss scales comparable.
fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(12, 6), sharey=True)

ax_train.plot(all_train_losses, color='blue')
ax_train.set(title='Training loss', xlabel='Training batch (all folds)', ylabel='Loss')
ax_train.grid(True)

ax_val.plot(all_val_losses, color='red')
ax_val.set(title='Validation loss', xlabel='Validation batch (all folds)')
ax_val.grid(True)

fig.suptitle('Learning Curve')
plt.show()