In [None]:
!pip install datasets transformers wandb

Collecting datasets
  Downloading datasets-3.0.1-py3-none-any.whl.metadata (20 kB)
Collecting wandb
  Downloading wandb-0.18.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.7 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.17-py310-none-any.whl.metadata (7.2 kB)
Collecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl.metadata (1.8 kB)
Collecting gitpython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.43-py3-none-any.whl.metadata (13 kB)
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-2.16.0-py2.py3-none-any.whl.metadata (9.8 kB)
Collecting setproctitle (from wandb)
  Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x

In [None]:
import csv
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
from transformers import AutoModel
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from Probe import ProbingModel
from Probe import train_probe

In [None]:
def csv_to_dataset(path, do_train_test_split=True, test_size=0.1, seed=42,
                   positive_label="POSITIVE"):
    """Load a sentiment CSV into a Hugging Face ``Dataset`` (or a train/val ``DatasetDict``).

    The CSV must have a header row with ``Text`` and ``Sentiment`` columns. A ``Sentiment``
    value equal to ``positive_label`` (compared case-insensitively, surrounding whitespace
    ignored) becomes label 1; every other value becomes label 0.

    Args:
        path: Path to a UTF-8 CSV file (a leading byte-order mark is tolerated).
        do_train_test_split: If True, return a ``DatasetDict`` with ``"train"`` and ``"val"``.
        test_size: Fraction of rows held out for the ``"val"`` split.
        seed: Seed for the split so it is reproducible.
        positive_label: Sentiment string that maps to label 1.

    Returns:
        A ``Dataset`` whose ``Text``/``Sentiment`` columns are renamed to ``text``/``labels``
        (other CSV columns are kept), or a ``DatasetDict`` of two such datasets.

    Raises:
        ValueError: If the CSV has no data rows.
    """
    positive = positive_label.strip().upper()
    rows = []
    # 'utf-8-sig' decodes plain UTF-8 unchanged but strips a BOM, which would otherwise be
    # glued onto the first header name and break the 'Text'/'Sentiment' lookups.
    with open(path, newline='', encoding='utf-8-sig') as csvfile:
        for row in csv.DictReader(csvfile):
            # Normalise before comparing so 'Positive' / ' POSITIVE ' are not silently
            # labelled negative; short rows yield None for missing fields.
            sentiment = (row['Sentiment'] or '').strip().upper()
            row['Sentiment'] = 1 if sentiment == positive else 0
            rows.append(row)

    if not rows:
        raise ValueError(f"No data rows found in {path!r}")

    # Column-oriented dict keyed by the CSV header (the first row's keys).
    dataset = Dataset.from_dict({key: [row[key] for row in rows] for key in rows[0]})
    dataset = dataset.rename_column("Sentiment", "labels")
    dataset = dataset.rename_column("Text", "text")

    if do_train_test_split:
        split = dataset.train_test_split(test_size=test_size, seed=seed)
        # Expose the held-out split as "val" instead of the library's default "test".
        dataset = DatasetDict({"train": split["train"], "val": split["test"]})

    return dataset

In [None]:
def compute_metrics(eval_pred):
    """Compute classification accuracy from a ``(logits, labels)`` evaluation pair.

    The predicted class is the arg-max over the last axis of ``logits``; accuracy is the
    number of predictions equal to ``labels`` divided by the number of predictions.
    """
    logits, labels = eval_pred
    predicted = np.argmax(logits, axis=-1)
    n_examples = predicted.shape[0]
    n_correct = (labels == predicted).sum()
    return {"accuracy": n_correct / n_examples}

In [None]:
# Authenticate with Weights & Biases; an API key already configured (e.g. in ~/.netrc) is reused.
# The bare call is the cell's last expression, so its return value (True on success) is displayed.
import wandb
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
# Start a W&B run, declare the fine-tuned checkpoint (v7) as an input model artifact,
# and download its files; `artifact_dir` is the local directory they are written to.
run = wandb.init()
artifact = run.use_artifact(
    'n11ch00/ChungliAoSentiment/Chungliao-xlm-roberta-sentiment:v7',
    type='model',
)
artifact_dir = artifact.download()

[34m[1mwandb[0m: Currently logged in as: [33mn1ch0[0m ([33mn11ch00[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact Chungliao-xlm-roberta-sentiment:v7, 682.23MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:14.5


In [None]:
# Hyperparams
BATCH_SIZE=16  # rows per tokenisation `map` batch and per DataLoader batch
LEARNING_RATE = 1e-5  # lr intended for the full model; the probing loops set their own lr
pre_trained_model_name = artifact_dir  # local directory of the downloaded W&B checkpoint, not a hub id

In [None]:
# Load the tokenizer that ships with the downloaded checkpoint.
tokenizer = AutoTokenizer.from_pretrained(pre_trained_model_name)

def tokenize(examples):
    """Tokenize a batch's ``text`` field, truncating to the tokenizer's maximum length.

    No padding here: ``DataCollatorWithPadding`` pads each DataLoader batch dynamically.
    Padding at ``map`` time would tie padded lengths to the ``map`` batch size and store
    redundant pad tokens in the dataset.
    """
    return tokenizer(examples["text"], truncation=True)

In [None]:
# Load both CSVs whole (no internal train/val split); the second file is the held-out test set.
train_dataset, test_dataset = (
    csv_to_dataset(csv_path, do_train_test_split=False)
    for csv_path in (
        "/content/Chungli_Ao_train_set.csv",
        "/content/Test_data_Chungli_ao.csv",
    )
)

In [None]:
# Shuffle each split once with a fixed seed so the row order is reproducible.
train_dataset, test_dataset = (
    split.shuffle(seed=42) for split in (train_dataset, test_dataset)
)

In [None]:
# Tokenize both splits in batches of BATCH_SIZE rows (adds input_ids / attention_mask columns).
train_dataset, test_dataset = (
    split.map(tokenize, batched=True, batch_size=BATCH_SIZE)
    for split in (train_dataset, test_dataset)
)

Map:   0%|          | 0/8579 [00:00<?, ? examples/s]

Map:   0%|          | 0/4095 [00:00<?, ? examples/s]

In [None]:
# Return only the model inputs and labels, as torch tensors, when rows are read from either split.
train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
test_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

In [None]:
# create datacollator: pads each batch (using the tokenizer's pad token) to the longest sequence in it
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
# One DataLoader per split, with batches padded on the fly by `data_collator`.
# No `shuffle=True`: both datasets were already shuffled once with a fixed seed.
train_loader, test_loader = (
    DataLoader(split, batch_size=BATCH_SIZE, collate_fn=data_collator)
    for split in (train_dataset, test_dataset)
)

In [None]:
# Load the fine-tuned checkpoint. Only its encoder (`model.pre_trained_model`) and that
# encoder's config are used below, by the per-layer probing loop. Each probing iteration
# builds its own loss and optimizer, so none are created for the full model here.
# NOTE(review): `SentimentClassifier` is neither defined nor imported in this notebook, so
# this cell raises NameError on a fresh kernel — import it from the module that defines it
# (TODO: confirm which module).
model = SentimentClassifier(pre_trained_model_name, num_classes=2)

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Train one linear probe per hidden-state index of the encoder and collect its metrics.
# Results are stored under key i + 1 for hidden_state_layer_index=i.
# NOTE(review): whether index 0 is the embedding output (in which case range(num_layers)
# never probes the last encoder layer) depends on train_probe — TODO confirm in Probe.py.
# NOTE(review): test_loader is passed as the evaluation set, so the "Validation" numbers
# that train_probe logs are test-set numbers.
PROBE_SEED = 42
PROBE_LR = 0.001
num_layers = model.pre_trained_model.config.num_hidden_layers
hidden_size = model.pre_trained_model.config.hidden_size
loss_fn = nn.CrossEntropyLoss()  # stateless, so one instance serves every probe

probe_metrics = {}
for i in range(num_layers):
    # Print the same 1-based number used as the dict key, so logs and results line up.
    print(f"probing layer {i + 1}/{num_layers} (hidden_state_layer_index={i})")
    # torch is not seeded anywhere else in this notebook; seed per layer so each probe's
    # initialisation, and therefore its metrics, is reproducible across re-runs.
    torch.manual_seed(PROBE_SEED + i)
    probing_model = ProbingModel(hidden_size, num_classes=2)
    optimizer = torch.optim.Adam(probing_model.parameters(), lr=PROBE_LR)
    probe_metrics[i + 1] = train_probe(
        probing_model, model.pre_trained_model, train_loader, test_loader,
        loss_fn, optimizer, DEVICE, num_epochs=1, hidden_state_layer_index=i,
    )




probing layer 0/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:53<00:00, 10.09it/s]


Epoch 1/1, Training Loss: 0.40602657198905945


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.12it/s]


Epoch 1/1, Validation Loss: 0.8015003628097475, Validation Accuracy: 0.5079
probing layer 1/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:55<00:00,  9.70it/s]


Epoch 1/1, Training Loss: 0.4074917733669281


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.11it/s]


Epoch 1/1, Validation Loss: 0.6799288482870907, Validation Accuracy: 0.5966
probing layer 2/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.88it/s]


Epoch 1/1, Training Loss: 0.31990358233451843


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.02it/s]


Epoch 1/1, Validation Loss: 0.5508350970922038, Validation Accuracy: 0.7282
probing layer 3/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.87it/s]


Epoch 1/1, Training Loss: 0.17153824865818024


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.06it/s]


Epoch 1/1, Validation Loss: 0.4894425696693361, Validation Accuracy: 0.7692
probing layer 4/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.86it/s]


Epoch 1/1, Training Loss: 0.03936019912362099


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.08it/s]


Epoch 1/1, Validation Loss: 0.4845302060130052, Validation Accuracy: 0.7878
probing layer 5/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.84it/s]


Epoch 1/1, Training Loss: 0.010503356344997883


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.06it/s]


Epoch 1/1, Validation Loss: 0.47248877929814626, Validation Accuracy: 0.8083
probing layer 6/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.86it/s]


Epoch 1/1, Training Loss: 0.02486865036189556


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.12it/s]


Epoch 1/1, Validation Loss: 0.5439422304334585, Validation Accuracy: 0.7878
probing layer 7/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.87it/s]


Epoch 1/1, Training Loss: 0.023085087537765503


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.05it/s]


Epoch 1/1, Validation Loss: 0.4994714360946091, Validation Accuracy: 0.8042
probing layer 8/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.86it/s]


Epoch 1/1, Training Loss: 0.00936709064990282


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.07it/s]


Epoch 1/1, Validation Loss: 0.49467485159402713, Validation Accuracy: 0.8073
probing layer 9/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.84it/s]


Epoch 1/1, Training Loss: 0.014044276438653469


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.08it/s]


Epoch 1/1, Validation Loss: 0.5048477016534889, Validation Accuracy: 0.8054
probing layer 10/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.85it/s]


Epoch 1/1, Training Loss: 0.013481899164617062


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.08it/s]


Epoch 1/1, Validation Loss: 0.5499586624282529, Validation Accuracy: 0.8085
probing layer 11/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.87it/s]


Epoch 1/1, Training Loss: 0.01614406146109104


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.09it/s]

Epoch 1/1, Validation Loss: 0.5287534504313953, Validation Accuracy: 0.8085





In [None]:
# XLM-RoBERTa (Chungliao-xlm-roberta-sentiment:v7) low lr linear probe results,
# keyed 1..num_hidden_layers (key i + 1 holds hidden_state_layer_index i)
probe_metrics

{1: {'train_losses': [0.40602657198905945],
  'val_losses': [0.8015003628097475],
  'val_accuracies': [0.5079365079365079]},
 2: {'train_losses': [0.4074917733669281],
  'val_losses': [0.6799288482870907],
  'val_accuracies': [0.5965811965811966]},
 3: {'train_losses': [0.31990358233451843],
  'val_losses': [0.5508350970922038],
  'val_accuracies': [0.7282051282051282]},
 4: {'train_losses': [0.17153824865818024],
  'val_losses': [0.4894425696693361],
  'val_accuracies': [0.7692307692307693]},
 5: {'train_losses': [0.03936019912362099],
  'val_losses': [0.4845302060130052],
  'val_accuracies': [0.7877899877899878]},
 6: {'train_losses': [0.010503356344997883],
  'val_losses': [0.47248877929814626],
  'val_accuracies': [0.8083028083028083]},
 7: {'train_losses': [0.02486865036189556],
  'val_losses': [0.5439422304334585],
  'val_accuracies': [0.7877899877899878]},
 8: {'train_losses': [0.023085087537765503],
  'val_losses': [0.4994714360946091],
  'val_accuracies': [0.8041514041514042]}

In [None]:
# Start a W&B run, declare the v8 checkpoint as an input model artifact, and download its
# files; `artifact_dir` is the local directory they are written to.
run = wandb.init()
artifact = run.use_artifact(
    'n11ch00/ChungliAoSentiment/Chungliao-xlm-roberta-sentiment:v8',
    type='model',
)
artifact_dir = artifact.download()

VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

[34m[1mwandb[0m: Downloading large artifact Chungliao-xlm-roberta-sentiment:v8, 682.23MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:11.2


In [None]:
# Point model loading at the checkpoint directory just downloaded.
# NOTE(review): the tokenizer and DataLoaders built earlier in the notebook are reused as-is;
# that is only valid if this checkpoint uses the same tokenizer — confirm.
pre_trained_model_name = artifact_dir

In [None]:
# Load the checkpoint; the probing loop only uses its encoder (`model.pre_trained_model`).
# NOTE(review): `SentimentClassifier` is neither defined nor imported in this notebook, so this
# raises NameError on a fresh kernel — import it from its defining module (TODO: confirm which).
model = SentimentClassifier(pre_trained_model_name, num_classes=2)

In [None]:
# Train one linear probe per hidden-state index of the high-lr checkpoint's encoder.
# Results are stored under key i + 1 for hidden_state_layer_index=i.
# NOTE(review): whether index 0 is the embedding output (in which case range(num_layers)
# never probes the last encoder layer) depends on train_probe — TODO confirm in Probe.py.
# NOTE(review): test_loader is passed as the evaluation set, so the "Validation" numbers
# that train_probe logs are test-set numbers.
PROBE_SEED = 42
PROBE_LR = 0.001
num_layers = model.pre_trained_model.config.num_hidden_layers
hidden_size = model.pre_trained_model.config.hidden_size
loss_fn = nn.CrossEntropyLoss()  # stateless, so one instance serves every probe

probe_metrics_high_lr = {}
for i in range(num_layers):
    # Print the same 1-based number used as the dict key, so logs and results line up.
    print(f"probing layer {i + 1}/{num_layers} (hidden_state_layer_index={i})")
    # Seed per layer so each probe's initialisation, and therefore its metrics, is
    # reproducible across re-runs.
    torch.manual_seed(PROBE_SEED + i)
    probing_model = ProbingModel(hidden_size, num_classes=2)
    optimizer = torch.optim.Adam(probing_model.parameters(), lr=PROBE_LR)
    probe_metrics_high_lr[i + 1] = train_probe(
        probing_model, model.pre_trained_model, train_loader, test_loader,
        loss_fn, optimizer, DEVICE, num_epochs=1, hidden_state_layer_index=i,
    )

probing layer 0/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:55<00:00,  9.67it/s]


Epoch 1/1, Training Loss: 0.4002941846847534


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.24it/s]


Epoch 1/1, Validation Loss: 0.8093096135417, Validation Accuracy: 0.5079
probing layer 1/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.77it/s]


Epoch 1/1, Training Loss: 0.3658834993839264


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.10it/s]


Epoch 1/1, Validation Loss: 0.769353742711246, Validation Accuracy: 0.4879
probing layer 2/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.80it/s]


Epoch 1/1, Training Loss: 0.19871389865875244


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.29it/s]


Epoch 1/1, Validation Loss: 0.8135207642335445, Validation Accuracy: 0.5538
probing layer 3/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.79it/s]


Epoch 1/1, Training Loss: 0.25496602058410645


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.28it/s]


Epoch 1/1, Validation Loss: 0.8579810707597062, Validation Accuracy: 0.5609
probing layer 4/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.78it/s]


Epoch 1/1, Training Loss: 0.4461720883846283


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.25it/s]


Epoch 1/1, Validation Loss: 0.7518998652230948, Validation Accuracy: 0.5538
probing layer 5/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.77it/s]


Epoch 1/1, Training Loss: 0.5933682918548584


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.21it/s]


Epoch 1/1, Validation Loss: 0.6957653716672212, Validation Accuracy: 0.5155
probing layer 6/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.77it/s]


Epoch 1/1, Training Loss: 0.4606015980243683


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.26it/s]


Epoch 1/1, Validation Loss: 0.7816742260474712, Validation Accuracy: 0.5048
probing layer 7/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.78it/s]


Epoch 1/1, Training Loss: 0.41799795627593994


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.23it/s]


Epoch 1/1, Validation Loss: 0.7625687315594405, Validation Accuracy: 0.5079
probing layer 8/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.78it/s]


Epoch 1/1, Training Loss: 0.43419334292411804


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.21it/s]


Epoch 1/1, Validation Loss: 0.769522828864865, Validation Accuracy: 0.5079
probing layer 9/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:55<00:00,  9.76it/s]


Epoch 1/1, Training Loss: 0.44877493381500244


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.21it/s]


Epoch 1/1, Validation Loss: 0.7526882617967203, Validation Accuracy: 0.5079
probing layer 10/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.78it/s]


Epoch 1/1, Training Loss: 0.434135764837265


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.25it/s]


Epoch 1/1, Validation Loss: 0.774161521345377, Validation Accuracy: 0.5079
probing layer 11/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.79it/s]


Epoch 1/1, Training Loss: 0.4858691394329071


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.22it/s]

Epoch 1/1, Validation Loss: 0.7609521937556565, Validation Accuracy: 0.5079





In [None]:
# Linear-probe results for the high-lr checkpoint (Chungliao-xlm-roberta-sentiment:v8),
# keyed 1..num_hidden_layers (key i + 1 holds hidden_state_layer_index i).
probe_metrics_high_lr

{1: {'train_losses': [0.4002941846847534],
  'val_losses': [0.8093096135417],
  'val_accuracies': [0.5079365079365079]},
 2: {'train_losses': [0.3658834993839264],
  'val_losses': [0.769353742711246],
  'val_accuracies': [0.4879120879120879]},
 3: {'train_losses': [0.19871389865875244],
  'val_losses': [0.8135207642335445],
  'val_accuracies': [0.5538461538461539]},
 4: {'train_losses': [0.25496602058410645],
  'val_losses': [0.8579810707597062],
  'val_accuracies': [0.5609279609279609]},
 5: {'train_losses': [0.4461720883846283],
  'val_losses': [0.7518998652230948],
  'val_accuracies': [0.5538461538461539]},
 6: {'train_losses': [0.5933682918548584],
  'val_losses': [0.6957653716672212],
  'val_accuracies': [0.5155067155067155]},
 7: {'train_losses': [0.4606015980243683],
  'val_losses': [0.7816742260474712],
  'val_accuracies': [0.5047619047619047]},
 8: {'train_losses': [0.41799795627593994],
  'val_losses': [0.7625687315594405],
  'val_accuracies': [0.5079365079365079]},
 9: {'tra

In [None]:
# Start a W&B run, declare the v2 checkpoint as an input model artifact, and download its
# files; `artifact_dir` is the local directory they are written to.
run = wandb.init()
artifact = run.use_artifact(
    'n11ch00/ChungliAoSentiment/Chungliao-xlm-roberta-sentiment:v2',
    type='model',
)
artifact_dir = artifact.download()

VBox(children=(Label(value='0.012 MB of 0.012 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

[34m[1mwandb[0m: Downloading large artifact Chungliao-xlm-roberta-sentiment:v2, 682.23MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:10.2


In [None]:
# Point model loading at the checkpoint directory just downloaded.
# NOTE(review): the tokenizer and DataLoaders built earlier in the notebook are reused as-is;
# that is only valid if this checkpoint uses the same tokenizer — confirm.
pre_trained_model_name = artifact_dir

In [None]:
# Load the checkpoint; the probing loop only uses its encoder (`model.pre_trained_model`).
# NOTE(review): `SentimentClassifier` is neither defined nor imported in this notebook, so this
# raises NameError on a fresh kernel — import it from its defining module (TODO: confirm which).
model = SentimentClassifier(pre_trained_model_name, num_classes=2)

In [None]:
# Train one linear probe per hidden-state index of this checkpoint's encoder.
# Results are stored under key i + 1 for hidden_state_layer_index=i.
# NOTE(review): whether index 0 is the embedding output (in which case range(num_layers)
# never probes the last encoder layer) depends on train_probe — TODO confirm in Probe.py.
# NOTE(review): test_loader is passed as the evaluation set, so the "Validation" numbers
# that train_probe logs are test-set numbers.
PROBE_SEED = 42
PROBE_LR = 0.001
num_layers = model.pre_trained_model.config.num_hidden_layers
hidden_size = model.pre_trained_model.config.hidden_size
loss_fn = nn.CrossEntropyLoss()  # stateless, so one instance serves every probe

probe_metrics_chungli_ao = {}
for i in range(num_layers):
    # Print the same 1-based number used as the dict key, so logs and results line up.
    print(f"probing layer {i + 1}/{num_layers} (hidden_state_layer_index={i})")
    # Seed per layer so each probe's initialisation, and therefore its metrics, is
    # reproducible across re-runs.
    torch.manual_seed(PROBE_SEED + i)
    probing_model = ProbingModel(hidden_size, num_classes=2)
    optimizer = torch.optim.Adam(probing_model.parameters(), lr=PROBE_LR)
    probe_metrics_chungli_ao[i + 1] = train_probe(
        probing_model, model.pre_trained_model, train_loader, test_loader,
        loss_fn, optimizer, DEVICE, num_epochs=1, hidden_state_layer_index=i,
    )

probing layer 0/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.87it/s]


Epoch 1/1, Training Loss: 0.4603988230228424


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.15it/s]


Epoch 1/1, Validation Loss: 0.7561907672788948, Validation Accuracy: 0.5079
probing layer 1/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:53<00:00,  9.99it/s]


Epoch 1/1, Training Loss: 0.3646644651889801


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.12it/s]


Epoch 1/1, Validation Loss: 0.6470430149929598, Validation Accuracy: 0.6479
probing layer 2/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:53<00:00,  9.95it/s]


Epoch 1/1, Training Loss: 0.37234124541282654


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.17it/s]


Epoch 1/1, Validation Loss: 0.6430993874091655, Validation Accuracy: 0.6720
probing layer 3/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:53<00:00, 10.00it/s]


Epoch 1/1, Training Loss: 0.2006334513425827


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.20it/s]


Epoch 1/1, Validation Loss: 0.8432200674433261, Validation Accuracy: 0.6530
probing layer 4/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:53<00:00,  9.98it/s]


Epoch 1/1, Training Loss: 0.25094401836395264


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.13it/s]


Epoch 1/1, Validation Loss: 0.7677480904385448, Validation Accuracy: 0.6872
probing layer 5/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:53<00:00,  9.96it/s]


Epoch 1/1, Training Loss: 0.020940139889717102


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.17it/s]


Epoch 1/1, Validation Loss: 0.5276083782082424, Validation Accuracy: 0.7624
probing layer 6/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:53<00:00,  9.98it/s]


Epoch 1/1, Training Loss: 0.5312010645866394


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.18it/s]


Epoch 1/1, Validation Loss: 0.4944854755885899, Validation Accuracy: 0.7695
probing layer 7/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:53<00:00,  9.99it/s]


Epoch 1/1, Training Loss: 0.36022916436195374


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.20it/s]


Epoch 1/1, Validation Loss: 0.5413723446545191, Validation Accuracy: 0.7177
probing layer 8/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:53<00:00,  9.99it/s]


Epoch 1/1, Training Loss: 0.08203145861625671


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.15it/s]


Epoch 1/1, Validation Loss: 0.48431427439209074, Validation Accuracy: 0.7553
probing layer 9/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:53<00:00,  9.96it/s]


Epoch 1/1, Training Loss: 0.03571152314543724


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.13it/s]


Epoch 1/1, Validation Loss: 0.47632002012687735, Validation Accuracy: 0.7487
probing layer 10/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:53<00:00,  9.97it/s]


Epoch 1/1, Training Loss: 0.00923784077167511


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.23it/s]


Epoch 1/1, Validation Loss: 0.722298288543243, Validation Accuracy: 0.7094
probing layer 11/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:53<00:00,  9.98it/s]


Epoch 1/1, Training Loss: 0.00728055601939559


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 13.14it/s]

Epoch 1/1, Validation Loss: 0.7945759353460744, Validation Accuracy: 0.7062





In [None]:
# Linear-probe results for the Chungliao-xlm-roberta-sentiment:v2 checkpoint, keyed
# 1..num_hidden_layers (key i + 1 holds hidden_state_layer_index i).
# NOTE(review): in the saved outputs this dict does not match the log printed by the probing
# loop above (e.g. first-layer train loss 0.4604 logged vs 0.4897 here), so it comes from a
# different execution — re-run the loop and this cell together before quoting numbers.
probe_metrics_chungli_ao

{1: {'train_losses': [0.4897282123565674],
  'val_losses': [0.7522707367315888],
  'val_accuracies': [0.5079365079365079]},
 2: {'train_losses': [0.37024474143981934],
  'val_losses': [0.6426583889406174],
  'val_accuracies': [0.6525030525030525]},
 3: {'train_losses': [0.33135315775871277],
  'val_losses': [0.6482788318535313],
  'val_accuracies': [0.6725274725274726]},
 4: {'train_losses': [0.221183180809021],
  'val_losses': [0.845246305456385],
  'val_accuracies': [0.6517704517704518]},
 5: {'train_losses': [0.37726080417633057],
  'val_losses': [0.7614935035235249],
  'val_accuracies': [0.6871794871794872]},
 6: {'train_losses': [0.02309548854827881],
  'val_losses': [0.5141496948199347],
  'val_accuracies': [0.7748473748473749]},
 7: {'train_losses': [0.40555116534233093],
  'val_losses': [0.4880461299326271],
  'val_accuracies': [0.7770451770451771]},
 8: {'train_losses': [0.3185204267501831],
  'val_losses': [0.5264493980794214],
  'val_accuracies': [0.7262515262515262]},
 9: {

In [None]:
# Start a W&B run, declare the v9 checkpoint as an input model artifact, and download its
# files; `artifact_dir` is the local directory they are written to.
run = wandb.init()
artifact = run.use_artifact(
    'n11ch00/ChungliAoSentiment/Chungliao-xlm-roberta-sentiment:v9',
    type='model',
)
artifact_dir = artifact.download()

VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

[34m[1mwandb[0m: Downloading large artifact Chungliao-xlm-roberta-sentiment:v9, 682.23MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:1.5


In [None]:
# Point model loading at the checkpoint directory just downloaded.
# NOTE(review): the tokenizer and DataLoaders built earlier in the notebook are reused as-is;
# that is only valid if this checkpoint uses the same tokenizer — confirm.
pre_trained_model_name = artifact_dir

In [None]:
# Load the checkpoint; the probing loop only uses its encoder (`model.pre_trained_model`).
# NOTE(review): `SentimentClassifier` is neither defined nor imported in this notebook, so this
# raises NameError on a fresh kernel — import it from its defining module (TODO: confirm which).
model = SentimentClassifier(pre_trained_model_name, num_classes=2)

In [None]:
# Train one linear probe per hidden-state index of this checkpoint's encoder.
# Results are stored under key i + 1 for hidden_state_layer_index=i.
# NOTE(review): whether index 0 is the embedding output (in which case range(num_layers)
# never probes the last encoder layer) depends on train_probe — TODO confirm in Probe.py.
# NOTE(review): test_loader is passed as the evaluation set, so the "Validation" numbers
# that train_probe logs are test-set numbers.
PROBE_SEED = 42
PROBE_LR = 0.001
num_layers = model.pre_trained_model.config.num_hidden_layers
hidden_size = model.pre_trained_model.config.hidden_size
loss_fn = nn.CrossEntropyLoss()  # stateless, so one instance serves every probe

probe_metrics_chungli_ao_high_lr = {}
for i in range(num_layers):
    # Print the same 1-based number used as the dict key, so logs and results line up.
    print(f"probing layer {i + 1}/{num_layers} (hidden_state_layer_index={i})")
    # Seed per layer so each probe's initialisation, and therefore its metrics, is
    # reproducible across re-runs.
    torch.manual_seed(PROBE_SEED + i)
    probing_model = ProbingModel(hidden_size, num_classes=2)
    optimizer = torch.optim.Adam(probing_model.parameters(), lr=PROBE_LR)
    probe_metrics_chungli_ao_high_lr[i + 1] = train_probe(
        probing_model, model.pre_trained_model, train_loader, test_loader,
        loss_fn, optimizer, DEVICE, num_epochs=1, hidden_state_layer_index=i,
    )

probing layer 0/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:55<00:00,  9.63it/s]


Epoch 1/1, Training Loss: 0.4896329343318939


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 12.98it/s]


Epoch 1/1, Validation Loss: 0.7616992554394528, Validation Accuracy: 0.5079
probing layer 1/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:55<00:00,  9.70it/s]


Epoch 1/1, Training Loss: 0.3229140341281891


Validation Epoch 1/1: 100%|██████████| 256/256 [00:20<00:00, 12.77it/s]


Epoch 1/1, Validation Loss: 0.6748345336527564, Validation Accuracy: 0.7045
probing layer 2/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:54<00:00,  9.77it/s]


Epoch 1/1, Training Loss: 0.21742366254329681


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 12.89it/s]


Epoch 1/1, Validation Loss: 0.5793233082804363, Validation Accuracy: 0.7692
probing layer 3/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:55<00:00,  9.76it/s]


Epoch 1/1, Training Loss: 0.1742027848958969


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 12.88it/s]


Epoch 1/1, Validation Loss: 0.5455051457684021, Validation Accuracy: 0.7734
probing layer 4/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:55<00:00,  9.75it/s]


Epoch 1/1, Training Loss: 0.33215779066085815


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 12.86it/s]


Epoch 1/1, Validation Loss: 0.5179167930909898, Validation Accuracy: 0.7853
probing layer 5/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:55<00:00,  9.71it/s]


Epoch 1/1, Training Loss: 0.08808115869760513


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 12.83it/s]


Epoch 1/1, Validation Loss: 0.509237514808774, Validation Accuracy: 0.7756
probing layer 6/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:55<00:00,  9.73it/s]


Epoch 1/1, Training Loss: 0.254708856344223


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 12.85it/s]


Epoch 1/1, Validation Loss: 0.4595531689701602, Validation Accuracy: 0.7939
probing layer 7/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:55<00:00,  9.75it/s]


Epoch 1/1, Training Loss: 0.16838772594928741


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 12.87it/s]


Epoch 1/1, Validation Loss: 0.4384326642029919, Validation Accuracy: 0.8017
probing layer 8/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:55<00:00,  9.74it/s]


Epoch 1/1, Training Loss: 0.03692241385579109


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 12.89it/s]


Epoch 1/1, Validation Loss: 0.5141817245166749, Validation Accuracy: 0.7629
probing layer 9/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:55<00:00,  9.72it/s]


Epoch 1/1, Training Loss: 0.026045143604278564


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 12.86it/s]


Epoch 1/1, Validation Loss: 0.5352793212223332, Validation Accuracy: 0.7753
probing layer 10/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:55<00:00,  9.72it/s]


Epoch 1/1, Training Loss: 0.02360418252646923


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 12.87it/s]


Epoch 1/1, Validation Loss: 0.48344978151726536, Validation Accuracy: 0.8288
probing layer 11/12


Training Epoch 1/1: 100%|██████████| 537/537 [00:55<00:00,  9.75it/s]


Epoch 1/1, Training Loss: 0.048343077301979065


Validation Epoch 1/1: 100%|██████████| 256/256 [00:19<00:00, 12.88it/s]

Epoch 1/1, Validation Loss: 0.5203596868232125, Validation Accuracy: 0.8254





In [None]:
# Linear-probe results for the Chungliao-xlm-roberta-sentiment:v9 checkpoint, keyed
# 1..num_hidden_layers (key i + 1 holds hidden_state_layer_index i).
# NOTE(review): in the saved outputs this dict does not match the log printed by the probing
# loop above (e.g. first-layer train loss 0.4896 logged vs 0.4847 here), so it comes from a
# different execution — re-run the loop and this cell together before quoting numbers.
probe_metrics_chungli_ao_high_lr

{1: {'train_losses': [0.48471641540527344],
  'val_losses': [0.7607034912798554],
  'val_accuracies': [0.5079365079365079]},
 2: {'train_losses': [0.15657581388950348],
  'val_losses': [0.6668008791748434],
  'val_accuracies': [0.7045177045177046]},
 3: {'train_losses': [0.30366814136505127],
  'val_losses': [0.5709647599433083],
  'val_accuracies': [0.7692307692307693]},
 4: {'train_losses': [0.1741148978471756],
  'val_losses': [0.5571531471214257],
  'val_accuracies': [0.76996336996337]},
 5: {'train_losses': [0.3348071873188019],
  'val_losses': [0.5246089289721567],
  'val_accuracies': [0.7841269841269841]},
 6: {'train_losses': [0.05789889022707939],
  'val_losses': [0.4935682862997055],
  'val_accuracies': [0.7785103785103785]},
 7: {'train_losses': [0.05695723369717598],
  'val_losses': [0.4669395922101103],
  'val_accuracies': [0.7921855921855921]},
 8: {'train_losses': [0.14471182227134705],
  'val_losses': [0.44384907954372466],
  'val_accuracies': [0.8007326007326008]},
 9: