In [None]:
import os
import json
import numpy as np
import pandas as pd
from tqdm import trange, tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
from src import *

Load and preprocess dataset

In [None]:
from datasets import load_from_disk, disable_caching

# Switch off the `datasets` caching layer before anything is loaded.
disable_caching()

DATASET_DIR = "./data/human_1k_demo"
demo_dataset = load_from_disk(DATASET_DIR)

# The JSON file lives inside the dataset directory. It is inverted below so
# that `celltype_map` is keyed by the JSON's values and yields the JSON's keys.
celltype_map_dict_dir = os.path.join(DATASET_DIR, "name_id_dict.json")
with open(celltype_map_dict_dir) as file:
    name_to_id = json.load(file)
celltype_map = dict((idx, name) for name, idx in name_to_id.items())

Define model

In [None]:
# Single precision is used both as torch's global default and, later in the
# notebook, as the dtype handed to the dataset wrappers.
dtype = torch.float32
torch.set_default_dtype(dtype)

# Pick the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# Passed to the model as `pretrained_llm`; left empty in this notebook.
# NOTE(review): presumably must point at local LLaMA weights before running - confirm.
PRETRAINED_LLM_PATH = ""

model = netscInterpreter(llm="llama", pretrained_llm=PRETRAINED_LLM_PATH, num_classes=1000, init_range=0.02)
print(model)

In [None]:
# Walk the parameters once, tallying every element and, separately, the
# elements that belong to tensors with requires_grad set.
total_params = 0
trainable_params = 0
for param in model.parameters():
    count = param.numel()
    total_params += count
    if param.requires_grad:
        trainable_params += count

print(f"model has {total_params} params\n")
print(f"model has {trainable_params} trainable params\n")

In [None]:
# --- Training configuration -------------------------------------------------
lr = 5e-5
per_gpu_batch_size = 8
gradient_accumulation_steps = 1
epochs = 8
warmup_steps = 1000

# Physical GPU indices to train on; one worker process is spawned per entry.
gpu_ids = [1,2]
num_gpus = len(gpu_ids)
# NOTE(review): an earlier cell already called torch.cuda.is_available(), so
# this process may have enumerated devices before the mask is set. Spawned
# workers start fresh and inherit the variable - confirm the single-process
# path sees the intended GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(map(str, gpu_ids))

# 90/10 split with a fixed seed so the held-out set is reproducible.
splited_dataset = demo_dataset.train_test_split(train_size=0.9, test_size=0.1, seed=0)

fit_dataset = {"train": Dataset(splited_dataset["train"], dtype=dtype, seed=0),
               "test": Dataset(splited_dataset["test"], dtype=dtype, seed=0)}

# Keyword arguments for `model.fit`, built once and shared by both launch
# paths. Previously each path spelled them out separately and the
# single-process call passed a misspelled `warnup_steps`, so the warm-up
# setting never reached `fit` under its real name.
fit_kwargs = dict(
    epochs=epochs,
    num_workers=4,
    num_gpus=num_gpus,
    lr=lr,
    batch_size=per_gpu_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    warmup_steps=warmup_steps,
    log_wandb=True,
    display_metrics=["accuracy", "f1"],
)

if num_gpus>1:
    import torch.multiprocessing as mp
    from functools import partial
    # Rendezvous address/port for the worker processes.
    os.environ['MASTER_ADDR'] = '127.0.1.1'
    os.environ['MASTER_PORT'] = '8848'
    fit_func = partial(model.fit, **fit_kwargs)
    mp.set_start_method("spawn", force=True)
    # Manager-backed dict passed to every worker; index 0 is read back below.
    manager = mp.Manager()
    results = manager.dict()
    try:
        # mp.spawn calls fit_func(i, fit_dataset, results) with i as the
        # process index, once per GPU.
        mp.spawn(fit_func, args=(fit_dataset, results,), nprocs=num_gpus, join=True)
        model: netscInterpreter = results[0]["model"]
    except KeyboardInterrupt:
        # Manual interrupt: stop any workers still alive. `model` keeps its
        # pre-training value in this case.
        for proc in mp.active_children():
            proc.terminate()
else:
    # Single-process path: call fit directly with index 0 and no shared dict.
    results = model.fit(0, fit_dataset, None, **fit_kwargs)
    model: netscInterpreter = results["model"]

In [None]:
# Run validation on the held-out split with gradient tracking disabled.
# NOTE(review): the device is hard-coded to "cuda:0" rather than reusing the
# `device` selected in the model cell; presumably this fails on CPU-only
# hosts - confirm against `val` in src.
test_split = fit_dataset["test"]
with torch.no_grad():
    model.val(test_split, celltype_map, device="cuda:0", model=True)

In [None]:
# Report evaluation metrics.
# NOTE(review): presumably relies on state populated by `model.val` in the previous cell - confirm in src.
model.print_metrics()

In [None]:
# Draw the confusion matrix.
# NOTE(review): presumably uses predictions cached by `model.val`, so run the validation cell first - confirm in src.
model.plot_confusion_matrix()

In [None]:
# Draw a UMAP plot.
# NOTE(review): which representation is projected, and whether it depends on a prior `model.val` run, is not visible here - confirm in src.
model.plot_umap()