### Implementing CORAL


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [None]:
def transform_labels_to_binary(y, num_classes):    
    transformed_labels = torch.zeros(y.shape[0], num_classes-1)
    
    for idx, label in enumerate(y):
        transformed_labels[idx, 0:label] = 1
        
    return transformed_labels

In [None]:
y = torch.tensor([0, 2, 3, 1, 2])

num_classes = len(y.unique())

transform_labels_to_binary(y, num_classes)

### Building Model

In [None]:
class OrdinalClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(OrdinalClassifier, self).__init__()
        
        # define layers
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes - 1)
        
        # independent bias 
        self.biases = nn.Parameter(torch.zeros(num_classes - 1))

    def forward(self, x):
        
        # compute layer output
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        output = self.fc3(x)
        
        # add bias term
        output += self.biases
        
        # return sigmoid
        return torch.sigmoid(output)

In [None]:
loss = nn.BCELoss()


X = torch.rand((5, 2))
y = torch.tensor([0, 2, 3, 1, 2])

In [None]:
transformed_labels = transform_labels_to_binary(y, len(y.unique()))

In [None]:
model = OrdinalClassifier(2, 5, num_classes)

In [None]:
output = model(X)
output

In [None]:
loss_value = loss(output, transformed_labels)

In [None]:
for i in range(100):
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    optimizer.zero_grad()
    
    # forward pass
    output = model(X)
    
    # compute loss
    loss_value = loss(output, transformed_labels)
    
    # backward pass
    loss_value.backward()
    
    # update weights
    optimizer.step()
    
    if i % 10 == 0:
        print(f"Epoch {i}, Loss: {loss_value.item()}")

In [None]:
model.eval()
output = model(X)
output

### CORAL For Sentense Classification

In [None]:
# pip install sentence-transformers

In [1]:
from sentence_transformers import SentenceTransformer
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model = SentenceTransformer('distilbert-base-nli-mean-tokens')

df = pd.read_pickle("data/review_data.pickle")
df.head(2)

Unnamed: 0,Review,Rating
0,not bad couple nights no nice looking hotel ou...,3
1,"wo n't planned trip group 11 including, booked...",2


In [3]:
embeddings = model.encode(
    df['Review'].tolist(),
    show_progress_bar=True,
    convert_to_tensor=True
)

y = df.Rating.values

Batches: 100%|██████████| 223/223 [00:35<00:00,  6.35it/s]


In [4]:
x_train, x_test, y_train, y_test = train_test_split(embeddings, y, test_size=0.2)

In [11]:
device = torch.device("mps")

In [12]:
def transform_labels_to_binary(y, num_classes):    
    transformed_labels = torch.zeros(y.shape[0], num_classes-1)
    
    for idx, label in enumerate(y):
        transformed_labels[idx, 0:label] = 1
        
    return transformed_labels

embeddings_train = torch.tensor(x_train).to(device)
embeddings_test = torch.tensor(x_test).to(device)

y_train = torch.tensor(y_train).to(device)
y_test = torch.tensor(y_test).to(device)

num_classes = len(y_test.unique())

transformed_labels = transform_labels_to_binary(y_train, num_classes).to(device)

  embeddings_train = torch.tensor(x_train).to(device)
  embeddings_test = torch.tensor(x_test).to(device)
  y_train = torch.tensor(y_train).to(device)
  y_test = torch.tensor(y_test).to(device)


In [13]:
class OrdinalClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(OrdinalClassifier, self).__init__()
        
        # define layers
        self.fc1 = nn.Linear(input_size, hidden_size).to('mps')
        self.fc2 = nn.Linear(hidden_size, hidden_size).to('mps')
        self.fc3 = nn.Linear(hidden_size, num_classes - 1).to('mps')
        
        # independent bias 
        self.biases = nn.Parameter(torch.zeros(num_classes - 1)).to('mps')

    def forward(self, x):
        
        # compute layer output
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        output = self.fc3(x)
        
        # add bias term
        output += self.biases
        
        # return sigmoid
        return torch.sigmoid(output)

In [14]:
model = OrdinalClassifier(
    embeddings_train.shape[1],
    30,
    num_classes
)

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [21]:
num_epochs = 10
batch_size = 64
num_batches = len(embeddings_train) // batch_size

In [22]:
# Training loop
for epoch in range(num_epochs):
    running_loss = 0.0
    
    # Mini-batch iteration
    for batch in range(num_batches):
        start_idx = batch * batch_size
        end_idx = min((batch + 1) * batch_size, len(embeddings_train))
        inputs = embeddings_train[start_idx:end_idx]
        
        # Generate binary labels for the mini-batch
        binary_labels = transformed_labels[start_idx:end_idx]
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(inputs)
        
        # Calculate the loss
        loss = criterion(outputs, binary_labels.float())
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
    # accuracy
    test_outputs = model(embeddings_test)
    test_binary = (test_outputs>0.5).to(int)
    test_labels = torch.sum(test_binary, 1)    
    accuracy = float(sum(test_labels == y_test)/y_test.shape[0])
    
    # Print average loss for each epoch
    print(f"Epoch {epoch+1}, Loss: {round(running_loss / num_batches, 3)}, Accuracy = {round(accuracy, 3)}")

Epoch 1, Loss: 0.006, Accuracy = 0.383
Epoch 2, Loss: 0.006, Accuracy = 0.393
Epoch 3, Loss: 0.005, Accuracy = 0.385
Epoch 4, Loss: 0.006, Accuracy = 0.388
Epoch 5, Loss: 0.004, Accuracy = 0.375
Epoch 6, Loss: 0.003, Accuracy = 0.384
Epoch 7, Loss: 0.006, Accuracy = 0.381
Epoch 8, Loss: 0.006, Accuracy = 0.391
Epoch 9, Loss: 0.004, Accuracy = 0.391
Epoch 10, Loss: 0.004, Accuracy = 0.385
