In [None]:
import pandas as pd
import numpy as np 
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [None]:
# Reproducibility: seed every RNG this notebook draws from (Python's `random`,
# NumPy's legacy global RNG, and PyTorch's default generator, which drives
# weight init, dropout masks and DataLoader shuffling).
import random

SEED = 1
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Load the annotated table and the precomputed embedding files; each .pt file
# holds a dict from which only the "embeddings" entry is kept.
# NOTE: torch.load unpickles its input — only load files you trust.
# (Author's note: test1 = xiang, test2 = smash.)
xiang_filtered = pd.read_csv("kr_at_enriched_dataset.csv")
xiang_filtered_KR_embeddings = torch.load("ATTest_KR.pt")["embeddings"]
xiang_filtered_AT_embeddings = torch.load("ATTest_AT.pt")["embeddings"]

In [None]:
def _has_digit(annotation):
    """Return 1 if ``str(annotation)`` contains any digit character, else 0."""
    return int(any(ch.isdigit() for ch in str(annotation)))


# Binary target: 1 when the Annotation text contains a digit, otherwise 0.
# Missing values stringify to "nan" and therefore map to 0.
xiang_filtered["alphasub"] = xiang_filtered["Annotation"].map(_has_digit)

In [None]:
# Sanity check: length of the value in the "at_sequence" column at row label 100.
len(xiang_filtered["at_sequence"][100])

In [None]:
# Inspect the loaded table, including the derived "alphasub" column.
xiang_filtered

In [None]:
# Class balance of the binary target. This is what the stratified split and the
# inverse-frequency class weights further down depend on.
xiang_filtered["alphasub"].value_counts()

In [None]:
# Average pooling: take the mean of each embedding tensor over dim=1, then
# squeeze(0) removes axis 0 when it has size 1.
# NOTE(review): this implies each tensor is [1, seq_len, dim]; the earlier note
# said [len_of_seq, 1, 1536], which would not pool over the sequence — confirm
# against the embedding export.
def _mean_pool(embedding):
    return embedding.mean(dim=1).squeeze(0)


xiang_filtered_KR_embeddings = list(map(_mean_pool, xiang_filtered_KR_embeddings))
xiang_filtered_AT_embeddings = list(map(_mean_pool, xiang_filtered_AT_embeddings))

In [None]:
# Stack the pooled AT vectors into a single feature matrix (one row per entry).
# Only the AT embeddings are used as model features in this notebook.
xiang_AT_embeddings = xiang_embeddings = torch.stack(xiang_filtered_AT_embeddings)

In [None]:
# Targets must be contiguous class indices: assign each distinct "alphasub"
# value an index 0..K-1 in ascending numeric order.
annotations_unique = np.sort(xiang_filtered["alphasub"].unique())
annotation_enumerated = dict(zip(annotations_unique, range(len(annotations_unique))))
print(annotation_enumerated)

In [None]:
# Encode each label as its class index via the mapping dict.
_label_indices = xiang_filtered["alphasub"].map(annotation_enumerated)
xiang_filtered["AnnotationEnumerated"] = _label_indices

In [None]:
# Show the encoded class indices.
print(xiang_filtered.loc[:, "AnnotationEnumerated"])

In [None]:
# Plain Python list of class indices (despite the "_np" suffix).
xiang_filtered_np = xiang_filtered["AnnotationEnumerated"].tolist()

In [None]:
# int64 class-index targets, the dtype nn.CrossEntropyLoss expects.
xiang_filtered_tensor = torch.as_tensor(xiang_filtered_np, dtype=torch.long)

In [None]:
from sklearn.model_selection import train_test_split

# Stratified 80/20 split with a fixed random_state so class proportions match
# across splits and the split is repeatable.
(x_train_tensor, x_test_tensor,
 y_train_tensor, y_test_tensor) = train_test_split(
    xiang_embeddings, xiang_filtered_tensor,
    test_size=0.2, random_state=1, stratify=xiang_filtered_tensor,
)

# Standardise every feature with statistics from the training split only (no
# test leakage); the 1e-9 keeps zero-variance features from dividing by zero.
mu = x_train_tensor.mean(0)
sigma = x_train_tensor.std(0) + 1e-9
x_train_tensor, x_test_tensor = (x_train_tensor - mu) / sigma, (x_test_tensor - mu) / sigma

for caption, split_values in (("x train len", x_train_tensor), ("y train len", y_train_tensor)):
    print(caption)
    print(len(split_values))

In [None]:
# Width of one pooled embedding vector; used as the classifier's input size.
xiang_embeddings[0].shape[0]

In [None]:

class kr_predict(nn.Module):
    """MLP classifier over pooled embedding vectors.

    Architecture: for each width in ``hidden_dims`` a Linear -> ReLU -> Dropout
    block, followed by a Linear output layer that returns raw logits (no
    softmax; pair with ``nn.CrossEntropyLoss``).

    The defaults reproduce the original 2-class model (hidden 512 -> 256,
    dropout 0.3), with identical layer order and state_dict keys. The earlier
    3-layer / 9-class experiment is
    ``kr_predict(input_dim=1536, hidden_dims=(512, 256, 128), num_classes=9)``.

    Args:
        input_dim: Size of each input vector. If ``None`` (original behaviour),
            it is read from ``xiang_embeddings[0].shape[0]`` in the notebook
            namespace.
        hidden_dims: Widths of the hidden layers, in order. An empty sequence
            gives a single linear layer.
        num_classes: Number of output logits.
        dropout_rate: Dropout probability applied after every hidden ReLU.
    """

    def __init__(self, input_dim=None, hidden_dims=(512, 256), num_classes=2, dropout_rate=0.3):
        super().__init__()
        if input_dim is None:
            # Backward-compatible default: infer width from the stacked embeddings.
            input_dim = xiang_embeddings[0].shape[0]
        layers = []
        in_features = input_dim
        for width in hidden_dims:
            layers += [nn.Linear(in_features, width), nn.ReLU(), nn.Dropout(dropout_rate)]
            in_features = width
        # Same module order as the original hand-written Sequential, so
        # checkpoints keyed hidden.0.* / hidden.3.* / out.* still load.
        self.hidden = nn.Sequential(*layers)
        self.out = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for ``x`` of shape ``(batch, input_dim)``."""
        return self.out(self.hidden(x))

In [None]:
#xiang_embeddings[0].shape[1]

In [None]:
# Instantiate the classifier; its input width comes from xiang_embeddings.
model = kr_predict()

In [None]:
# Apple silicon: use the MPS GPU backend when available, otherwise the CPU.
_mps_available = torch.backends.mps.is_available()
device = torch.device("mps" if _mps_available else "cpu")
model.to(device)

print("MPS is available! Using Apple Silicon GPU." if _mps_available
      else "MPS is not available. CPU Fallback.")

In [None]:
from collections import Counter

num_classes = 2
class_counts = Counter(y_train_tensor.numpy())
total_samples = sum(class_counts.values())

# Inverse-frequency ("balanced") weights: total / (num_classes * count_c).
# A class missing from the training split is treated as count 1.
_weight_values = [
    total_samples / (num_classes * class_counts.get(cls, 1))
    for cls in range(num_classes)
]
class_weights = torch.tensor(_weight_values, dtype=torch.float32).to(device)

loss = nn.CrossEntropyLoss(weight=class_weights)
adam = optim.Adam(model.parameters(), lr=0.00001)
# NOTE(review): gamma=0.01 every 200 epochs takes lr 1e-5 -> 1e-7 -> 1e-9, so
# epochs past ~200 barely move the weights — confirm this is intended.
scheduler = optim.lr_scheduler.StepLR(adam, step_size=200, gamma=0.01)


In [None]:
# Inspect the standardised training features.
x_train_tensor

In [None]:
# Make sure the targets are int64 class indices for CrossEntropyLoss (a no-op
# if they already are). The features are used as-is.
y_train_tensor = y_train_tensor.to(torch.long)

In [None]:
# Show the training targets.
print(y_train_tensor)

In [None]:
import matplotlib.pyplot as plt
from IPython.display import clear_output
import time

# Training configuration. The loop, the plot's x-axis range and the progress
# readout all read NUM_EPOCHS, so they cannot drift apart.
NUM_EPOCHS = 500
PLOT_EVERY = 1  # redraw the live plot every N epochs (raise to train faster)
batch_size = 8

plt.ioff()
epochs_list = []
losses_list = []

train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
# drop_last=True discards the final partial batch every epoch.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)


def _plot_training_progress(epochs, losses, total_epochs):
    """Redraw the live training-loss figure.

    Args:
        epochs: 1-based epoch numbers completed so far (non-empty).
        losses: Average training loss recorded for each entry of ``epochs``.
        total_epochs: Planned number of epochs; sets the x-axis range and the
            progress percentage.
    """
    current_epoch, current_loss = epochs[-1], losses[-1]
    clear_output(wait=True)
    fig, ax = plt.subplots(figsize=(14, 8))
    ax.plot(epochs, losses, 'b-', linewidth=1.5, alpha=0.9)
    ax.set_xlim(0, total_epochs)
    if len(losses) > 1:
        loss_min, loss_max = min(losses), max(losses)
        loss_range = loss_max - loss_min
        if loss_range > 0:  # identical lower/upper limits make matplotlib warn
            ax.set_ylim(max(0, loss_min - 0.1 * loss_range), loss_max + 0.1 * loss_range)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Average Loss', fontsize=12)
    ax.set_title(f'LIVE Training Loss - Epoch {current_epoch} | Loss: {current_loss:.4f}',
                 fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.plot(current_epoch, current_loss, 'ro', markersize=6, alpha=0.8)
    initial_loss = losses[0]
    if len(losses) > 1 and initial_loss != 0:
        improvement = (initial_loss - current_loss) / initial_loss * 100
        ax.text(0.02, 0.98,
                f'Initial: {initial_loss:.4f}\nCurrent: {current_loss:.4f}\nImprovement: {improvement:.1f}%',
                transform=ax.transAxes, fontsize=10, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
    ax.axvline(x=current_epoch, color='red', alpha=0.3, linewidth=2)
    ax.text(current_epoch, ax.get_ylim()[1] * 0.95, f'{current_epoch / total_epochs:.1%}',
            ha='center', fontsize=10, color='red', fontweight='bold')
    fig.tight_layout()
    plt.show()
    # Close explicitly so one figure per redraw does not pile up in memory.
    plt.close(fig)


for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0.0
    samples_seen = 0
    for seqs, anns in train_loader:
        seqs = seqs.to(device)
        anns = anns.to(device)
        output = model(seqs)
        output_loss = loss(output, anns)
        adam.zero_grad()
        output_loss.backward()
        adam.step()
        epoch_loss += output_loss.item() * seqs.size(0)
        samples_seen += seqs.size(0)
    scheduler.step()
    # Average over the samples actually trained on; with drop_last=True that can
    # be fewer than len(train_dataset).
    avg_loss = epoch_loss / samples_seen if samples_seen else 0.0
    epochs_list.append(epoch + 1)
    losses_list.append(avg_loss)
    if (epoch + 1) % PLOT_EVERY == 0 or epoch + 1 == NUM_EPOCHS:
        _plot_training_progress(epochs_list, losses_list, NUM_EPOCHS)


In [None]:
test_dataset = TensorDataset(x_test_tensor, y_test_tensor)
# No drop_last here: every test sample must be scored. The model contains no
# BatchNorm layers, so a small final batch is fine.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Filled by accuracy(); consumed by the classification-report cell.
all_predictions = []
all_targets = []


def accuracy(loader=None):
    """Return the model's accuracy, in percent, on ``loader``.

    Side effect: refills the module-level ``all_predictions`` and
    ``all_targets`` lists with the predicted and true class indices for
    ``loader``, so they can be passed to ``classification_report``.

    Args:
        loader: DataLoader yielding ``(inputs, targets)`` batches. Defaults to
            ``test_loader``.

    Returns:
        Accuracy in ``[0, 100]``, or ``float('nan')`` if the loader is empty.
    """
    if loader is None:
        loader = test_loader
    # Reset rather than append, so calling accuracy() again doesn't duplicate
    # entries in the classification report.
    all_predictions.clear()
    all_targets.clear()
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)  # (batch, num_classes)
            predicted = logits.argmax(dim=1)  # highest-scoring class per sample
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
    if total == 0:
        return float("nan")
    return 100 * correct / total


accuracy()

In [None]:

from sklearn.metrics import classification_report
import numpy as np


# Display names in class-index order: class_names[i] describes class index i.
class_names = ['0','1']
# Pass labels explicitly so the report always lines up with class_names. Without
# it, sklearn raises a ValueError when a class is absent from both the targets
# and the predictions (label count != len(target_names)).
print(classification_report(all_targets, all_predictions,
                            labels=list(range(len(class_names))),
                            target_names=class_names))