In [1]:
import os
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim

# Select the compute device: the first CUDA GPU if torch can see one, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:
# Input locations (absolute, machine-specific paths).
data_root = "/home/kgulbarg/thesis/INR_SCHIC/Nagano_Data"  # one sub-directory per phase label
path = "/home/kgulbarg/thesis/INR_SCHIC/mm10.main.nochrM.chrom.sizes"  # chrom.sizes file

# Bin size, in the same units as the chrom.sizes lengths (bp): 1 Mb.
resolution = 10**6

# Phase directory name -> integer class id.
LABEL_TO_ID = dict(G1=0, early_S=1, mid_S=2, late_S=3)

# Chromosome name -> number of bins; populated by the chrom-sizes cell.
n_bins = dict()
# NOTE(review): not populated by any cell visible here.
cells = list()

In [None]:
# read_chrom_sizes: fill `n_bins` with the number of `resolution`-sized bins per chromosome.
# Each usable line of the chrom.sizes file is "<chrom> <length>". Lines without exactly
# two fields, and '#' comment lines, are skipped.
# A chromosome of length L spans positions 0..L-1, i.e. bins 0..(L-1)//resolution, so it
# has ceil(L / resolution) bins (including the trailing partial bin). The former
# `L // resolution + 1` added a phantom empty bin whenever L was an exact multiple.
with open(path, "r") as f:
    for line in f:
        parts = line.strip().split()
        if len(parts) != 2 or parts[0].startswith("#"):
            continue
        chr_name, length = parts[0], int(parts[1])
        n_bins[chr_name] = -(-length // resolution)  # integer ceil division

print(n_bins)
print(f'total bins:  {sum(n_bins.values())}')

{'chr1': 198, 'chr2': 183, 'chr3': 161, 'chr4': 157, 'chr5': 153, 'chr6': 150, 'chr7': 153, 'chr8': 132, 'chr9': 125, 'chr10': 131, 'chr11': 123, 'chr12': 122, 'chr13': 121, 'chr14': 126, 'chr15': 105, 'chr16': 99, 'chr17': 96, 'chr18': 91, 'chr19': 62, 'chrX': 172}
total bins:  2660


In [None]:
def _has_reversed_pairs(df):
    """Return True if some off-diagonal bin pair occurs both as (i, j) and as (j, i).

    ``df`` must have integer-like ``bin1`` and ``bin2`` columns; diagonal rows are ignored.
    """
    off_diag = df[df["bin1"] != df["bin2"]].copy()
    off_diag["min"] = np.minimum(off_diag["bin1"], off_diag["bin2"])
    off_diag["max"] = np.maximum(off_diag["bin1"], off_diag["bin2"])
    off_diag["ori"] = np.where(off_diag["bin1"] < off_diag["bin2"], "ij", "ji")
    # A pair seen in both orientations has two distinct "ori" values in its group.
    return bool(off_diag.groupby(["min", "max"])["ori"].nunique().eq(2).any())


def get_data_prepared(cell_dir, chrom_bins=None):
    """Load one cell's per-chromosome contact files into INR training arrays.

    Every ``chr*.txt`` file in ``cell_dir`` whose chromosome is present in
    ``chrom_bins`` is read. Files for other chromosomes (e.g. chrY, chrM) are
    skipped. Duplicate (bin1, bin2) rows are summed.

    Parameters
    ----------
    cell_dir : str
        Directory holding whitespace-separated ``<chrom>.txt`` files with a
        ``bin1 bin2 count`` header row.
    chrom_bins : dict[str, int], optional
        Number of bins per chromosome. Defaults to the notebook-level
        ``n_bins`` table built from the chrom.sizes file.

    Returns
    -------
    coords : np.ndarray of float32, shape (N, 2)
        (bin1, bin2) linearly scaled so that bin 0 -> -1 and bin nb-1 -> 1.
    targets : np.ndarray of float32, shape (N, 1)
        ``log1p`` of the summed contact count for each (bin1, bin2) pair.

    Both arrays have N == 0 when no usable chromosome file is found.
    """
    if chrom_bins is None:
        chrom_bins = n_bins  # notebook-level table from the chrom-sizes cell

    coords_list = []
    targets_list = []

    for fname in sorted(os.listdir(cell_dir)):
        if not (fname.startswith("chr") and fname.endswith(".txt")):
            continue

        chr_name = fname[:-4]  # 'chr16' from 'chr16.txt'
        # Skip chromosomes without a bin count (chrY, chrM, ...) instead of raising KeyError.
        if chr_name not in chrom_bins:
            continue
        nb = chrom_bins[chr_name]

        # Header on the first line: bin1    bin2    count
        df = pd.read_csv(os.path.join(cell_dir, fname), sep=r"\s+", header=0, comment="#")

        has_reversed = _has_reversed_pairs(df)
        if has_reversed:
            print("Reversed-order duplicates?", has_reversed, cell_dir, chr_name)

        # Aggregate scattered reads of the same pair.
        # NOTE(review): (i, j) and (j, i) are kept as separate rows here; confirm
        # whether they should be merged into one canonical (min, max) pair.
        df_agg = df.groupby(["bin1", "bin2"], as_index=False)["count"].sum()
        arr = df_agg[["bin1", "bin2", "count"]].to_numpy(int)
        b1, b2, cnt = arr[:, 0], arr[:, 1], arr[:, 2]

        # Linear scaling of bin indices to [-1, 1]; nb - 1 is the last bin index.
        # max(..., 1) keeps a single-bin chromosome from dividing by zero.
        last_bin = max(nb - 1, 1)
        x = 2.0 * (b1.astype(np.float32) / last_bin) - 1.0
        y = 2.0 * (b2.astype(np.float32) / last_bin) - 1.0
        # NOTE(review): coords carry no chromosome id, so pairs from different
        # chromosomes share the same [-1, 1]^2 space in the model input.
        coords_list.append(np.stack([x, y], axis=1))

        # log(1 + cnt): non-linear transform to stabilise variance.
        # Alternative (Anscombe): 2.0 * np.sqrt(cnt + 0.375)
        targets_list.append(np.log1p(cnt)[:, None])

    if not coords_list:
        return np.empty((0, 2), dtype=np.float32), np.empty((0, 1), dtype=np.float32)

    coords = np.concatenate(coords_list, axis=0).astype(np.float32)
    targets = np.concatenate(targets_list, axis=0).astype(np.float32)
    return coords, targets

# print("coords shape:", coords_i.shape, "targets shape:", targets_i.shape)
# print("coords example (first 5):\n", coords_i[:5])
# print("targets example (first 5):\n", targets_i[:5].ravel())

In [None]:
class ContactNet(nn.Module):
    """MLP mapping a normalised (bin1, bin2) coordinate to a transformed contact count.

    The ReLU output of ``fc3`` is returned alongside the prediction as a
    per-coordinate embedding of size ``emb_dim``. There is currently no bottleneck.
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.fc1 = nn.Linear(2, 64)
        self.fc2 = nn.Linear(64, 512)
        self.fc3 = nn.Linear(512, emb_dim)
        self.fc4 = nn.Linear(emb_dim, 1)  # predicts the target (log1p(count) upstream)

    def forward(self, x):
        h = torch.relu(self.fc1(x))
        h = torch.relu(self.fc2(h))
        emb = torch.relu(self.fc3(h))  # (N, emb_dim): one embedding per input coordinate
        return self.fc4(emb), emb


def train_latent(coords, targets, n_epochs=100, lr=7e-4, embedding_dim=256, seed=None):
    """Fit a fresh ContactNet on one cell and return its mean embedding.

    Training is full-batch Adam on MSE. The loss is printed once, after the final epoch.

    Parameters
    ----------
    coords : array-like, shape (N, 2)
        Normalised (bin1, bin2) coordinates. Any numeric dtype is cast to float32.
    targets : array-like, shape (N, 1) or (N,)
        Regression targets. Reshaped to (N, 1) so MSELoss does not broadcast to (N, N).
    n_epochs, lr, embedding_dim : int, float, int
        Training length, Adam learning rate, and fc3 width (the output vector length).
    seed : int, optional
        If given, ``torch.manual_seed(seed)`` is called before the model is built,
        making initialisation reproducible. Note this reseeds torch's global RNG.

    Returns
    -------
    np.ndarray, shape (embedding_dim,)
        Mean of the fc3 embeddings over all N coordinates.

    Raises
    ------
    ValueError
        If ``coords`` is empty. Otherwise the loss and the mean would be NaN.
    """
    if len(coords) == 0:
        raise ValueError("train_latent: no contacts to fit (empty coords)")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if seed is not None:
        torch.manual_seed(seed)

    # as_tensor (unlike from_numpy) accepts float64 input and casts it for the float32 model.
    X = torch.as_tensor(coords, dtype=torch.float32, device=device)
    y = torch.as_tensor(targets, dtype=torch.float32, device=device).reshape(-1, 1)

    # NOTE(review): every cell gets an independently initialised network, so
    # embedding dimensions are not guaranteed to be aligned across cells.
    model = ContactNet(embedding_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    model.train()
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        preds, _ = model(X)
        loss = loss_fn(preds, y)
        loss.backward()
        optimizer.step()
        if epoch == n_epochs - 1:
            print(f"epoch {epoch}, loss {loss.item():.4f}")

    # After training, summarise the cell as the average embedding over all coordinates.
    model.eval()
    with torch.no_grad():
        _, emb = model(X)
        cell_vector = emb.mean(dim=0)

    return cell_vector.cpu().numpy()


In [None]:
# For every labelled cell directory, fit one model and save its mean embedding
# next to the directory as "<cell_dir>_latent.npy" (reload with np.load(path)).
for label in sorted(os.listdir(data_root)):
    label_dir = os.path.join(data_root, label)

    # Only process known phase directories. Announce the label after this check
    # so skipped entries do not print a misleading "Started" line.
    if not os.path.isdir(label_dir) or label not in LABEL_TO_ID:
        continue
    print(f"Started: Making latent representations for {label} cells.")

    for cell in sorted(os.listdir(label_dir)):
        cell_dir = os.path.join(label_dir, cell)
        # label_dir also holds the *_latent.npy files written below; only descend into directories.
        if not os.path.isdir(cell_dir):
            continue

        coords, targets = get_data_prepared(cell_dir)
        if len(coords) == 0:
            print(f"Skipped {cell_dir}: no usable contacts.")
            continue

        cell_vector = train_latent(coords, targets)
        latent_path = cell_dir + "_latent.npy"
        np.save(latent_path, cell_vector)

Using device: cuda
epoch 99, loss 0.0299
Using device: cuda
epoch 99, loss 0.0333
Using device: cuda
epoch 99, loss 0.0186
Using device: cuda
epoch 99, loss 0.0209
Using device: cuda
epoch 99, loss 0.0164
Using device: cuda
epoch 99, loss 0.0422
Using device: cuda
epoch 99, loss 0.0206
Using device: cuda
epoch 99, loss 0.0195
Using device: cuda
epoch 99, loss 0.0204
Using device: cuda
epoch 99, loss 0.0195
Using device: cuda
epoch 99, loss 0.0181
Using device: cuda
epoch 99, loss 0.0209
Using device: cuda
epoch 99, loss 0.0328
Using device: cuda
epoch 99, loss 0.0368
Using device: cuda
epoch 99, loss 0.0215
Using device: cuda
epoch 99, loss 0.0214
Using device: cuda
epoch 99, loss 0.0219
Using device: cuda
epoch 99, loss 0.0216
Using device: cuda
epoch 99, loss 0.0213
Using device: cuda
epoch 99, loss 0.0208
Using device: cuda
epoch 99, loss 0.0418
Using device: cuda
epoch 99, loss 0.0138
Using device: cuda
epoch 99, loss 0.0163
Using device: cuda
epoch 99, loss 0.0382
Using device: cu