In [1]:
import random
import importlib
from datasets import Dataset
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel, TrainingArguments, Trainer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import prokbert

from prokbert.ProkBERTDataCollator import DataCollatorForGenomeNetwork
from prokbert.model.genome_network import modeling_genome_network, configuration_genome_network

In [4]:
# importlib.reload(prokbert)

In [5]:
# Seed the RNGs used in this notebook so repeated runs start from the same
# state: Python's `random` drives the synthetic sequence generator defined
# below, and the torch calls seed the CPU and all CUDA generators.
seed = 42
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

## Dataset

In [6]:
model_name = 'neuralbioinfo/prokbert-mini'

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

In [8]:
print("PAD token:", tokenizer.pad_token, "->", tokenizer.pad_token_id)
print("CLS token:", tokenizer.cls_token, "->", tokenizer.cls_token_id)
print("SEP token:", tokenizer.sep_token, "->", tokenizer.sep_token_id)
print("MASK token:", tokenizer.mask_token, "->", tokenizer.mask_token_id)
print("UNK token:", tokenizer.unk_token, "->", tokenizer.unk_token_id)
print("VOCAB SIZE:", tokenizer.vocab_size, len(tokenizer))

PAD token: [PAD] -> 0
CLS token: [CLS] -> 2
SEP token: [SEP] -> 3
MASK token: [MASK] -> 4
UNK token: [UNK] -> 1
VOCAB SIZE: 4101 4101


### Random Genome Dataset

In [9]:
def random_gene_sequence(low=10, high=20):
    """Return a random DNA string over the alphabet ``ACGT``.

    The length is drawn with ``random.randint(low, high)`` (both bounds
    inclusive), then each base is picked independently with
    ``random.choice``. Uses the module-level ``random`` generator, so the
    result depends on the current global seed.

    Args:
        low: minimum sequence length (inclusive).
        high: maximum sequence length (inclusive).

    Returns:
        A ``str`` made only of the characters A, C, G and T.
    """
    length = random.randint(low, high)
    bases = []
    for _ in range(length):
        bases.append(random.choice("ACGT"))
    return "".join(bases)

In [10]:
def create_random_genome_dataset(
        dataset_num=100,
        gene_per_genom_low=2,
        gene_per_genom_high=5,
        gene_seq_low=10,
        gene_seq_high=20
):
    """Build a synthetic genome dataset as a column-oriented dict.

    Args:
        dataset_num: number of genomes to generate; genome ids run from 1
            to ``dataset_num`` inclusive.
        gene_per_genom_low: minimum number of genes per genome (inclusive).
        gene_per_genom_high: maximum number of genes per genome (inclusive).
        gene_seq_low: minimum gene length, forwarded to
            ``random_gene_sequence``.
        gene_seq_high: maximum gene length, forwarded to
            ``random_gene_sequence``.

    Returns:
        A dict with three parallel lists, one entry per genome:
        ``"genom"`` (the genome id), ``"gene_nums"`` (the gene indices
        ``0..n-1``) and ``"sequences"`` (the list of gene strings).
    """
    genome_ids = []
    gene_index_lists = []
    sequence_lists = []

    for genome_id in range(1, dataset_num + 1):
        # Draw the gene count first, then the genes, so the global RNG is
        # consumed in a fixed order for each genome.
        gene_count = random.randint(gene_per_genom_low, gene_per_genom_high)
        sequence_lists.append(
            [random_gene_sequence(gene_seq_low, gene_seq_high) for _ in range(gene_count)]
        )
        gene_index_lists.append(list(range(gene_count)))
        genome_ids.append(genome_id)

    return {
        "genom": genome_ids,
        "gene_nums": gene_index_lists,
        "sequences": sequence_lists,
    }


In [11]:
# Small toy corpus: 100 genomes, 4-7 genes each, 5-10 bases per gene.
# NOTE(review): the 5-base gene 'GCCCA' in the first row tokenizes to just
# [CLS] [SEP] in the tokenized output further down, while the 6-base gene in
# the same row gets one real token — presumably the tokenizer needs at least
# 6 bases to emit a token. Consider gene_seq_low=6 if empty genes are not
# intended.
genoms = create_random_genome_dataset(
    dataset_num=100,
    gene_per_genom_low=4,
    gene_per_genom_high=7,
    gene_seq_low=5,
    gene_seq_high=10
)

In [12]:
dataset = Dataset.from_dict(genoms)

In [14]:
# Preview the head of the dataset. The break is checked after printing, so
# rows 0 through 11 are shown (twelve rows, not ten).
for i, row in enumerate(dataset):
    print(f"Row {i}: {row}")
    if i>10:
        break

Row 0: {'genom': 1, 'gene_nums': [0, 1, 2, 3], 'sequences': ['GCCCA', 'ATAAACCACT', 'TGACTG', 'CCGAATA']}
Row 1: {'genom': 2, 'gene_nums': [0, 1, 2, 3, 4, 5], 'sequences': ['GATATAG', 'GCAACGACAT', 'TGCGGCG', 'ACCCTTGCGA', 'AGTGAC', 'GCTTTCGCC']}
Row 2: {'genom': 3, 'gene_nums': [0, 1, 2, 3, 4, 5], 'sequences': ['TTGCCTAAAC', 'CTATTTGAAG', 'GAGTCTAGCA', 'GCCGCAGTAA', 'GCACAAT', 'CCTCG']}
Row 3: {'genom': 4, 'gene_nums': [0, 1, 2, 3, 4, 5, 6], 'sequences': ['CGTGTT', 'CCAGA', 'CCAAACAAG', 'CGTCC', 'TCTTCAATGT', 'TAAATGAC', 'CTCTCG']}
Row 4: {'genom': 5, 'gene_nums': [0, 1, 2, 3, 4, 5, 6], 'sequences': ['ATAAAA', 'CTTTCT', 'CTATG', 'GTTCCGCA', 'AGAATCAAC', 'ACTAC', 'AATGGCGCG']}
Row 5: {'genom': 6, 'gene_nums': [0, 1, 2, 3, 4, 5, 6], 'sequences': ['GTGAAT', 'AACGCGACG', 'CTGAGAC', 'AACGGCG', 'GTGAAT', 'AAGCGCT', 'TAAACAGCT']}
Row 6: {'genom': 7, 'gene_nums': [0, 1, 2, 3, 4], 'sequences': ['GGAGC', 'CAGTCCCCTA', 'GTCGCA', 'ATCCTGGC', 'ACTGGA']}
Row 7: {'genom': 8, 'gene_nums': [0, 1, 2, 3

In [15]:
# Tokenize genome by genome: each tokenizer call receives one genome's list of
# gene sequences, so padding=True pads only to the longest gene within that
# genome. The result is a plain Python list with one tokenizer output per
# genome (tensors shaped genes x tokens, as the sample below shows).
tokenized_dataset = [tokenizer(genom["sequences"], padding=True, return_tensors="pt") for genom in dataset.to_list()]

In [16]:
# NOTE(review): the return value is not assigned, so this cell only displays
# the column-stripped Dataset shown below. Per the `datasets` documentation
# remove_columns returns a new Dataset, so `dataset` itself should be
# unchanged — confirm, and drop this cell if it is not needed.
dataset.remove_columns(["sequences", "gene_nums", "genom"])

Dataset({
    features: [],
    num_rows: 0
})

In [17]:
tokenized_dataset[0]

{'input_ids': tensor([[   2,    3,    0,    0,    0,    0,    0],
        [   2,  774, 3082,   25,   86,  332,    3],
        [   2, 3619,    3,    0,    0,    0,    0],
        [   2, 1416, 1553,    3,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0]])}

In [18]:
# Collator for genome-level masked modelling. The batch printed further down
# shows it returning input_ids / attention_mask / token_type_ids plus `labels`
# and a per-gene `labels_mask`.
# NOTE(review): mask_replace_prob + random_replace_prob = 1.0, so presumably
# every selected gene is either masked or randomised and none is kept as-is —
# confirm against DataCollatorForGenomeNetwork.
data_collator = DataCollatorForGenomeNetwork(
    tokenizer,
    mlm=True,
    mlm_probability=0.7,
    mask_replace_prob=0.6,
    random_replace_prob=0.4
)

In [19]:
loader = DataLoader(tokenized_dataset, batch_size=2, collate_fn=data_collator)

In [20]:
# Grab the first collated batch explicitly; `batch` is reused by the
# inspection cells and the model calls further down. next(iter(...)) raises
# StopIteration straight away if the loader is empty, instead of silently
# leaving `batch` undefined as a for/break loop would.
batch = next(iter(loader))
print({k: v.shape for k, v in batch.items()})

{'input_ids': torch.Size([2, 6, 7]), 'attention_mask': torch.Size([2, 6, 7]), 'token_type_ids': torch.Size([2, 6, 7]), 'labels': torch.Size([2, 6, 7]), 'labels_mask': torch.Size([2, 6])}


In [21]:
batch["input_ids"]

tensor([[[   2,    3,    0,    0,    0,    0,    0],
         [   2,  774, 3082,   25,   86,  332,    3],
         [   2, 2088,    3,    0,    0,    0,    0],
         [   2,    4,    4,    3,    0,    0,    0],
         [   2,    3,    0,    0,    0,    0,    0],
         [   2,    3,    0,    0,    0,    0,    0]],

        [[   2, 2257,  823,    3,    0,    0,    0],
         [   2,    4,    4,    4,    4,    4,    3],
         [   2, 3694, 2475,    3,    0,    0,    0],
         [   2, 2943, 1649, 3883, 3582, 2055,    3],
         [   2,  742,    3,    0,    0,    0,    0],
         [   2,    4,    4,    4,    4,    3,    0]]])

In [22]:
batch["attention_mask"]

tensor([[[1, 1, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 0, 0, 0, 0],
         [1, 1, 1, 1, 0, 0, 0],
         [1, 1, 0, 0, 0, 0, 0],
         [1, 1, 0, 0, 0, 0, 0]],

        [[1, 1, 1, 1, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 0]]])

In [23]:
batch["labels"]

tensor([[[-100, -100, -100, -100, -100, -100, -100],
         [-100, -100, -100, -100, -100, -100, -100],
         [   2, 3619,    3,    0,    0,    0,    0],
         [   2, 1416, 1553,    3,    0,    0,    0],
         [-100, -100, -100, -100, -100, -100, -100],
         [-100, -100, -100, -100, -100, -100, -100]],

        [[-100, -100, -100, -100, -100, -100, -100],
         [   2, 2315, 1053,  102,  393, 1560,    3],
         [-100, -100, -100, -100, -100, -100, -100],
         [   2,  356, 1411, 1534, 2027, 3997,    3],
         [-100, -100, -100, -100, -100, -100, -100],
         [   2, 2562, 2043, 4062, 3946,    3,    0]]])

In [24]:
batch["labels_mask"]

tensor([[False, False,  True,  True, False, False],
        [False,  True, False,  True, False,  True]])

## Genome Network

In [25]:
input_ids=batch["input_ids"]
attention_mask=batch["attention_mask"]
token_type_ids=batch["token_type_ids"]

In [26]:
prokbert_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

for param in prokbert_model.parameters():
    param.requires_grad = False

Some weights of ProkBertModel were not initialized from the model checkpoint at neuralbioinfo/prokbert-mini and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [33]:
# Embed each genome separately: zipping the batch tensors iterates along
# dim 0, so every step feeds one genome's (genes, tokens) slice to the frozen
# ProkBERT encoder and keeps one pooled vector per gene.
# NOTE(review): this reads pooler_output, and the load warning above reports
# pooler.dense.weight / pooler.dense.bias as newly initialized — so these gene
# embeddings pass through an untrained pooler layer. Confirm that is intended.
embedding_outputs = [
    prokbert_model(
        input_ids=input_id,
        attention_mask=attn_mask,
        token_type_ids=token_type_id
    ).pooler_output.detach()
    for input_id, attn_mask, token_type_id in zip(input_ids, attention_mask, token_type_ids)
]
print([embedding_output.shape for embedding_output in embedding_outputs])

[torch.Size([6, 384]), torch.Size([6, 384])]


In [34]:
inputs_embeds = torch.stack(embedding_outputs, dim=0)

In [36]:
config = configuration_genome_network.GenomeNetworkConfig()

In [37]:
model = modeling_genome_network.GenomeNetwork(config)

In [38]:
outputs = model(inputs_embeds)

In [39]:
outputs.last_hidden_state.shape, outputs[0].shape

(torch.Size([2, 6, 384]), torch.Size([2, 6, 384]))

## Masked Language Modeling

In [40]:
genome_network_config = configuration_genome_network.GenomeNetworkConfig()

In [41]:
mlm_model = modeling_genome_network.GenomeNetworkForMaskedLM(genome_network_config, embedding_model=prokbert_model)

In [42]:
mlm_output = mlm_model(**batch)

In [43]:
mlm_output

MaskedLMOutput(loss=tensor(0.7463, grad_fn=<MseLossBackward0>), logits=None, hidden_states=None, attentions=None)

## Training

In [44]:
training_args = TrainingArguments(
    output_dir="genome_network_mlm_model",
    logging_strategy="epoch",
    eval_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=10,
    weight_decay=0.01,
    push_to_hub=False,
)

In [45]:
# NOTE(review): eval_dataset is the same list as train_dataset, so the
# "Validation Loss" reported during training is measured on the training
# genomes rather than on held-out data.
trainer = Trainer(
    model=mlm_model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_dataset,
    data_collator=data_collator
)

In [46]:
trainer.train()



Epoch,Training Loss,Validation Loss
1,0.7356,0.729276
2,0.7276,0.725528
3,0.7271,0.724808
4,0.7258,0.724417
5,0.7192,0.724007
6,0.7228,0.723323
7,0.7267,0.722874
8,0.7226,0.723181
9,0.7245,0.722394
10,0.7221,0.723476




TrainOutput(global_step=40, training_loss=0.7253991663455963, metrics={'train_runtime': 22.6956, 'train_samples_per_second': 44.061, 'train_steps_per_second': 1.762, 'total_flos': 8630394974208.0, 'train_loss': 0.7253991663455963, 'epoch': 10.0})