# Train word2vec

In [30]:
import pickle
import numpy as np
import pandas as pd

from utils.knowledge_db import TOKENS

In [39]:
#with open(r"C:\Users\Felix\code\uni\UniVie\master-thesis\data\train_data\dev_train_9.0k_data.pkl", 'rb') as f:
with open("/export/share/krausef99dm/data/data_test/test_9.0k_data.pkl", 'rb') as f:
#with open("/export/share/krausef99dm/data/data_train/train_9.0k_data.pkl", 'rb') as f:
    rna_data, target_ids, targets, targets_bin = pickle.load(f)

In [24]:
# Load ptr data
with open("/export/share/krausef99dm/data/ptr_data/ptr_data.pkl", 'rb') as f:
    raw_data = pickle.load(f)

In [31]:
indices_seq_train = pd.read_csv("/export/share/krausef99dm/data/data_train/train_9.0k_indices.csv")

In [40]:
len(rna_data)

26829

In [21]:
rna_data[0].shape

torch.Size([3382, 4])

## Sequence only

### Raw ptr_data

In [60]:
seq_kmers = []
for idx in indices_seq_train.identifier.unique():
    seq = raw_data[idx]["fasta"]
    seq_kmers = ["".join(seq[i:i+k]) for i in range(len(list(seq)) - k + 1)]
    seq_kmers.append(seq_kmers)

In [67]:
len(seq_kmers)

6324

### If using encoded rna_data

In [41]:
# remove all columns except the sequence
rna_data = [rna_data[i][:, 0] for i in range(len(rna_data))]

In [42]:
# build dictionary to map integers to tokens
int2token = {i+1: token for i, token in enumerate(TOKENS)}

In [43]:
def generate_kmers(rna_data, k=3):
    rna_data_kmers = []
    for seq in rna_data:
        seq = seq.tolist()
        seq = [int2token.get(i) for i in seq]
        seq_kmers = ["".join(seq[i:i+k]) for i in range(len(seq) - k + 1)]
        rna_data_kmers.append(seq_kmers)
        
    return rna_data_kmers

In [44]:
seq_kmers = generate_kmers(rna_data)

### Train word2vec

In [68]:
import gensim

embedding_size = 64

model = gensim.models.Word2Vec(seq_kmers, vector_size=embedding_size, window=12, min_count=0, workers=4)

In [69]:
import torch
import torch.nn as nn

# Get vocabulary and embeddings
vocab = list(model.wv.index_to_key)
embedding_matrix = torch.tensor(np.array([model.wv[word] for word in vocab], dtype=np.float32))

# Create k-mer to index mapping
kmer_to_index = {kmer: idx for idx, kmer in enumerate(vocab)}

In [72]:
kmer_to_index

{'TTT': 0,
 'AAA': 1,
 'CTG': 2,
 'CAG': 3,
 'TGG': 4,
 'AGA': 5,
 'CCT': 6,
 'CCA': 7,
 'GAA': 8,
 'AAG': 9,
 'GGA': 10,
 'GAG': 11,
 'CCC': 12,
 'AGG': 13,
 'GCC': 14,
 'TGA': 15,
 'TGT': 16,
 'TCT': 17,
 'AGC': 18,
 'GGC': 19,
 'CTT': 20,
 'GCT': 21,
 'GGG': 22,
 'TGC': 23,
 'CTC': 24,
 'GTG': 25,
 'ATG': 26,
 'TTC': 27,
 'ATT': 28,
 'TCC': 29,
 'TCA': 30,
 'TTG': 31,
 'ACA': 32,
 'AAT': 33,
 'GCA': 34,
 'CAA': 35,
 'CAT': 36,
 'CAC': 37,
 'AGT': 38,
 'ACT': 39,
 'ACC': 40,
 'TTA': 41,
 'TAA': 42,
 'TAT': 43,
 'GAT': 44,
 'AAC': 45,
 'GTT': 46,
 'GAC': 47,
 'ATA': 48,
 'ATC': 49,
 'GGT': 50,
 'GTC': 51,
 'CTA': 52,
 'TAC': 53,
 'GTA': 54,
 'CGG': 55,
 'CCG': 56,
 'TAG': 57,
 'GCG': 58,
 'CGC': 59,
 'CGA': 60,
 'ACG': 61,
 'CGT': 62,
 'TCG': 63}

In [70]:
# Store model
model_data = {"kmer_to_index": kmer_to_index, "embedding_matrix": embedding_matrix}
with open("/export/share/krausef99dm/data/w2v_model_data.pkl", 'wb') as f:
    pickle.dump(model_data, f)

In [None]:
# Load model
#with open("/export/share/krausef99dm/data/w2v_model_data_dev.pkl", 'rb') as f:
    #model_data = pickle.load(f)

### Apply embeddings

In [16]:
class KMerEmbedding(nn.Module):
    def __init__(self, embedding_matrix):
        super().__init__()
        num_kmers, embedding_dim = embedding_matrix.shape
        self.embedding = nn.Embedding(num_kmers, embedding_dim)
        self.embedding.weight.data.copy_(embedding_matrix)
        self.embedding.weight.requires_grad = False  # Freeze embeddings

    def forward(self, kmer_indices):
        return self.embedding(kmer_indices)

In [17]:
kmer_embedding = KMerEmbedding(embedding_matrix)

In [18]:
# Test embedding
kmer_indices = torch.tensor([kmer_to_index[kmer] for kmer in seq_kmers[0]], dtype=torch.long)
kmer_embedding(kmer_indices)

tensor([[-1.3562, -1.1731, -3.3002,  ...,  3.7479, -1.8328, -0.3770],
        [-4.9456, -1.6875, -0.2400,  ..., -6.3022, -0.0254,  1.1563],
        [-0.3298, -1.6727, -3.6700,  ..., -4.2719, -0.5685, -2.8513],
        ...,
        [ 2.3797, -3.1560,  1.5470,  ...,  2.8186,  1.8949,  4.6125],
        [ 1.2922,  2.6484, -4.6266,  ...,  2.2645,  1.0404,  0.3241],
        [-1.3976, -5.5828, -0.2735,  ..., -0.4360,  0.9395,  0.5901]])

## Embeddings for all features