In [2]:
import fire
from esm_ecn.model import LitSAE
from esm_ecn.data import train_data_loader, val_data_loader, test_data_loader
from esm_ecn.train import setup_experiment, load_best_checkpoint
from esm_ecn.constants import DATA_FOLDER

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# --- Data -------------------------------------------------------------------
# All three values are forwarded unchanged to the train/val/test loader factories.
model_type = "esmc_300m"
batch_size = 2048
cls = True  # NOTE(review): presumably selects CLS-token embeddings -- confirm in esm_ecn.data

# --- Model ------------------------------------------------------------------
# Dictionary expansion factor: the SAE is built with dict_dim = dict_exp * input_dim.
dict_exp = 8

# --- Experiment / trainer ---------------------------------------------------
# These are handed to setup_experiment() in the training cell.
project = "esm_cls_SAE"
experiment_name = None
resume = False  # the training cell raises NotImplementedError when this is True
wandb_debug = True
accelerator = "cuda"
epochs = 3

In [4]:
# Fail fast: resuming is not supported, so bail out before any data is loaded
# or an experiment/trainer is created for a run that can never start.
if resume:
    raise NotImplementedError("Resuming training is not yet implemented")

# Build the three data loaders from the shared configuration.
train_loader = train_data_loader(model_type, batch_size, cls)
val_loader = val_data_loader(model_type, batch_size, cls)
test_loader = test_data_loader(model_type, batch_size, cls)

# The embedding width is read from the first element of the first training
# sample; the SAE dictionary is dict_exp times wider than its input.
input_dim = train_loader.dataset[0][0].shape[0]
model = LitSAE(model_dim=input_dim, dict_dim=dict_exp * input_dim, sparsity_coefficient=1.0)

# setup_experiment returns the trainer together with the experiment name that
# is used below to locate the checkpoint directory.
# NOTE: `experiment_name` is overwritten here, so re-running this cell passes
# the previously returned name back into setup_experiment.
trainer, experiment_name = setup_experiment(experiment_name, project, resume, wandb_debug, accelerator, epochs)
print(experiment_name)
checkpoint_path = DATA_FOLDER / 'checkpoints' / experiment_name / "best-checkpoint.ckpt"

trainer.fit(model, train_loader, val_loader)

Loading train embeddings
Loaded embeddings, shape: torch.Size([178302, 960])
Loading train labels
Loaded labels, shape: torch.Size([178302, 4793])
Loading dev embeddings
Loaded embeddings, shape: torch.Size([23010, 960])
Loading dev labels
Loaded labels, shape: torch.Size([23010, 4793])
Loading test embeddings
Loaded embeddings, shape: torch.Size([22183, 960])
Loading test labels
Loaded labels, shape: torch.Size([22183, 4793])


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


/home/ishan/miniforge3/envs/ecn/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/ishan/miniforge3/envs/ecn/lib/python3.10/site- ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX 6000 Ada Generation') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type   | Params | Mode 
----------------------------------------------
0 | encoder_DF | Linear | 7.4 M  

dhts3kgg
                                                                           

/home/ishan/miniforge3/envs/ecn/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.
/home/ishan/miniforge3/envs/ecn/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


Epoch 2: 100%|██████████| 88/88 [00:03<00:00, 22.77it/s, v_num=3kgg, train/reconstruction=7.31e-5, train/sparsity=54.90, train/loss=0.477, val/reconstruction=7.33e-5, val/sparsity=65.50, val/loss=0.483]  

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 88/88 [00:05<00:00, 17.13it/s, v_num=3kgg, train/reconstruction=7.31e-5, train/sparsity=54.90, train/loss=0.477, val/reconstruction=7.33e-5, val/sparsity=65.50, val/loss=0.483]


In [5]:
# Rebuild the SAE from the checkpoint file, using the same constructor
# arguments as at training time, then evaluate it on the test split.
sae_init_kwargs = {
    "model_dim": input_dim,
    "dict_dim": dict_exp * input_dim,
    "sparsity_coefficient": 1.0,
}
model = LitSAE.load_from_checkpoint(checkpoint_path, **sae_init_kwargs)

# Left as the last expression so the returned metrics are displayed.
trainer.test(model, test_loader)

/home/ishan/miniforge3/envs/ecn/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/ishan/miniforge3/envs/ecn/lib/python3.10/site- ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/ishan/miniforge3/envs/ecn/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 11/11 [00:00<00:00, 23.22it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test/loss           0.48292267322540283
   test/reconstruction     7.340726006077603e-05
      test/sparsity          65.19487762451172
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test/reconstruction': 7.340726006077603e-05,
  'test/sparsity': 65.19487762451172,
  'test/loss': 0.48292267322540283}]

In [6]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import numpy as np
import torch

def _encode_loader(model, loader, device):
    """Run every batch of ``loader`` through ``model`` and stack the results.

    Each batch is unpacked as ``(features, labels)``. The features are moved to
    ``device`` before the forward pass so they live on the same device as the
    model. The model is expected to return two tensors per batch, which this
    notebook calls ``x_sparse`` and ``x_dense``
    (NOTE(review): their exact meaning is defined by LitSAE.forward -- confirm).

    Returns:
        Tuple ``(sparse, dense, inputs, labels)`` of NumPy arrays, each
        concatenated over all batches along the first axis.

    Raises:
        ValueError: from NumPy if ``loader`` yields no batches.
    """
    sparse_parts, dense_parts, input_parts, label_parts = [], [], [], []
    with torch.no_grad():
        for features, labels in loader:
            x_sparse, x_dense = model(features.to(device))
            sparse_parts.append(x_sparse.cpu().numpy())
            dense_parts.append(x_dense.cpu().numpy())
            input_parts.append(features.cpu().numpy())
            label_parts.append(labels.cpu().numpy())
    return (
        np.vstack(sparse_parts),
        np.vstack(dense_parts),
        np.vstack(input_parts),
        np.concatenate(label_parts),
    )


# Use the GPU when one is present; otherwise fall back to the CPU so the cell
# still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
model.to(device)

# Extract features and labels from the train_loader.
# X_*       : first model output (x_sparse)
# X_*_dense : second model output (x_dense)
# X_*_og    : the unmodified input embeddings (now stacked like the others)
X_train, X_train_dense, X_train_og, y_train = _encode_loader(model, train_loader, device)

# Extract features and labels from the test_loader
X_test, X_test_dense, X_test_og, y_test = _encode_loader(model, test_loader, device)

In [7]:
def _first_active_class(multi_hot):
    """Return, for each row, the column index of its first non-zero entry.

    Vectorized replacement for calling ``torch.nonzero`` row by row. When a
    row has several non-zero entries only the first (lowest index) is kept.

    Args:
        multi_hot: 2-D array-like of shape (n_samples, n_classes).

    Returns:
        1-D int64 ``torch.Tensor`` of length n_samples.

    Raises:
        ValueError: if the input is not 2-D (for example when this cell is
            re-run on labels that were already converted).
        IndexError: if some row has no non-zero entry.
    """
    rows = np.asarray(multi_hot)
    if rows.ndim != 2:
        raise ValueError(
            f"expected a 2-D multi-hot label matrix, got an array with {rows.ndim} dimension(s)"
        )
    active = rows != 0
    if not active.any(axis=1).all():
        raise IndexError("at least one row has no non-zero label")
    # argmax over a boolean row returns the position of the first True.
    return torch.from_numpy(active.argmax(axis=1).astype(np.int64))


y_train = _first_active_class(y_train)
print(y_train.shape)
y_test = _first_active_class(y_test)
print(y_test.shape)

torch.Size([178302])
torch.Size([22183])


In [8]:
# Distinct class ids present in the training labels, with how often each occurs.
labels, counts = torch.unique(y_train, return_counts=True)
print(len(labels))

# Keep only the classes that occur more than 20 times in the training split.
is_frequent = counts > 20
final_labels = labels[is_frequent]
print(len(final_labels))

4277
958


In [9]:
def relabel_labels(labels, final_labels):
    """Map raw class ids onto contiguous indices over the retained classes.

    Every id listed in ``final_labels`` is replaced by its position in that
    list. Every id that is not listed is mapped to the single shared index
    ``len(final_labels)``, which acts as a catch-all class.

    Args:
        labels: 1-D tensor (or array / sequence) of raw class ids.
        final_labels: 1-D tensor (or array / sequence) of the class ids to keep.

    Returns:
        1-D int64 ``torch.Tensor`` with the same length as ``labels``. The dtype
        is int64 even for empty input.
    """
    kept_ids = torch.as_tensor(final_labels).tolist()
    label_map = {label_id: idx for idx, label_id in enumerate(kept_ids)}
    other_idx = len(kept_ids)
    # Convert to plain Python values once instead of calling .item() per element.
    new_labels = [label_map.get(label_id, other_idx) for label_id in torch.as_tensor(labels).tolist()]
    # An explicit dtype keeps the result integral when new_labels is empty.
    return torch.tensor(new_labels, dtype=torch.long)

# Re-index both splits against the retained classes.
# NOTE: y_train / y_test are overwritten, so re-running this cell on its own
# would remap ids that have already been remapped.
y_train, y_test = (relabel_labels(split_labels, final_labels) for split_labels in (y_train, y_test))

In [10]:
# Train a logistic regression model on the X_*_dense features.
# The explicit `multi_class='multinomial'` argument was dropped: it is
# deprecated in recent scikit-learn releases, and the lbfgs solver already
# optimises the multinomial loss for multiclass targets.
log_reg = LogisticRegression(
    solver='lbfgs',
    max_iter=50,
    verbose=1,
)
log_reg.fit(X_train_dense, y_train)

# Predict and evaluate the model
y_pred = log_reg.predict(X_test_dense)
accuracy = accuracy_score(y_test, y_pred)
print(f"Test Accuracy: {accuracy:.4f}")



KeyboardInterrupt: 

In [11]:
# Train a logistic regression model on the X_* features (first model output).
# The explicit `multi_class='multinomial'` argument was dropped: it is
# deprecated in recent scikit-learn releases, and the lbfgs solver already
# optimises the multinomial loss for multiclass targets.
log_reg = LogisticRegression(
    solver='lbfgs',
    max_iter=50,
    verbose=1,
)
log_reg.fit(X_train, y_train)

# Predict and evaluate the model
y_pred = log_reg.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Test Accuracy: {accuracy:.4f}")



Test Accuracy: 0.0869


In [11]:
from sklearn.neighbors import KNeighborsClassifier

# 10-nearest-neighbour probe fitted on the X_*_dense features.
# fit() returns the estimator itself, so construction and fitting are chained.
knn = KNeighborsClassifier(n_neighbors=10).fit(X_train_dense, y_train)

# Score the probe on the held-out test split.
y_pred_knn = knn.predict(X_test_dense)
accuracy_knn = accuracy_score(y_true=y_test, y_pred=y_pred_knn)
print(f"KNN Test Accuracy: {accuracy_knn:.4f}")

KNN Test Accuracy: 0.4006


In [12]:
from sklearn.neighbors import KNeighborsClassifier

# 10-nearest-neighbour probe fitted on the X_* features (first model output).
# fit() returns the estimator itself, so construction and fitting are chained.
knn = KNeighborsClassifier(n_neighbors=10).fit(X_train, y_train)

# Score the probe on the held-out test split.
y_pred_knn = knn.predict(X_test)
accuracy_knn = accuracy_score(y_true=y_test, y_pred=y_pred_knn)
print(f"KNN Test Accuracy: {accuracy_knn:.4f}")

KNN Test Accuracy: 0.7562
