In [1]:
import sys
import os
import hydra
from omegaconf import OmegaConf

# add parent directory to path
sys.path.append(os.path.abspath(os.path.join('..')))

# initialize hydra

In [2]:
hydra.initialize(config_path="../config", version_base="1.1")

# Choose which config to load
config_name = "config"  # Change this to use a different config
print(f"Loading config: {config_name}")

# Load the config
cfg = hydra.compose(
    config_name=config_name, 
    overrides=["experiment=gpra_dna"]
)

# Display the loaded config
print(OmegaConf.to_yaml(cfg))

Loading config: config
dataset:
  _target_: datasets.gpra_dna_dataset.GPRADNADataset
  data_dir: /orcd/data/omarabu/001/njwfish/DistributionEmbeddings/data/gpra_processed
  set_size: ${experiment.set_size}
  num_quantiles: 100
  window_width: 3
  max_seq_length: 129
  encoder_tokenizer: dna
  hyena_tokenizer: char
  num_sets: null
  seed: ${seed}
encoder:
  _target_: encoder.dna_conv_encoder.DNAConvEncoder
  in_channels: 5
  hidden_channels: 64
  out_channels: 128
  hidden_dim: 256
  latent_dim: ${experiment.latent_dim}
  num_layers: 3
  kernel_size: 7
  seq_length: ${dataset.max_seq_length}
  pool_type: mean
  agg_type: mean
model:
  _target_: types.NoneType
generator:
  _target_: generator.hyenadna_generator.HyenaDNAGenerator
  latent_dim: ${experiment.latent_dim}
  condition_dim: 128
  d_model: 128
  n_layer: 6
  max_seq_len: ${dataset.max_seq_length}
  condition_method: prefix
  temperature: 1.0
optimizer:
  _target_: torch.optim.Adam
  _partial_: true
  lr: ${experiment.lr}
  beta

In [3]:
from torch.utils.data import DataLoader

dataset = hydra.utils.instantiate(cfg.dataset)
        


In [4]:
batch = dataset[0]

In [40]:
batch['raw_texts'].tolist()

['TGCATTTTTTTCACATCCTAGAGTTGATGGAATCCCCGGCCTTACAATCATGCGGTTATGCGTCCGGGTGTCGTGGGCTAATTCGTGACGCACGTTAGGTTACGGCTGTT',
 'TGCATTTTTTTCACATCAGGTCCATGCGGATTCTGAACGTGGCATAGGTGGGAAGTGGGCATCTGGGGGATGTGGCCTCGGCACCAGAGAGTGGGGTGGTTACGGCTGTT',
 'TGCATTTTTTTCACATCTTGGATTTATTAGGGGTTACGTGATTTTAGATATTCAACCATTTGAATAGCAGAGGAAAGAGTTTGCGTTTGGCAGGTAGGTTACGGCTGTT',
 'TGCATTTTTTTCACATCTTATCCGCAATGGCTACAGGGTAGGGGGTTCGTAACGATTTGACGGATTTTTGGGCACGCATCTGTAAGTAAGTGTACCGGTTACGGCTGTT',
 'TGCATTTTTTTCACATCAGCAAGCAGTTATAATTTGGATTTTGTATTTTTTCGGTGTTGAGGTTTTGTGGTGCAGTTATGTCATGGTGCTAATTAGGGTTACGGCTGTT',
 'TGCATTTTTTTCACATCCGAGATATTATGTAAAAACCCCGTCGCTTTGTTAATATTTGAAAGCTACTACTATCGTGCCTTTTTAGAATGGTGGCTTCGGTTACGGCTGTT',
 'TGCATTTTTTTCACATCTATCGATAGCGGTGTTAATACCGTTAATCTATGATTCAAAGGCACTAACGCGCTAAATGCTGAGATCGCGCATCCCGCGGGGTTACGGCTGTT',
 'TGCATTTTTTTCACATCGGTTGAAGACGCTTTGGACACTTGGTCATTGTTTAATATTTAGGCGTCAGATTTAGTTTGGACGTAATGCCCCGGTTAATGGTTACGGCTGTT',
 'TGCATTTTTTTCACATCTACGGGTGTTTTTGGTATGTGCGTCCCTGTCTAAATACGGCGGCTTGAGTAAAGGTGTTGTTAC

In [7]:
dataset.data[0]['tokenized'].keys()

dict_keys(['encoder_inputs', 'hyena_input_ids', 'hyena_attention_mask'])

In [35]:
dataset.data[:3]# ['tokenized'] ['hyena_input_ids'][0]

[{'condition': 'GSE104878_20161024_average_promoter_ELs_per_seq_3p1E7_Gal_ALL',
  'center_quantile': 52,
  'window_width': 3,
  'sequences': array(['TGCATTTTTTTCACATCCTAGAGTTGATGGAATCCCCGGCCTTACAATCATGCGGTTATGCGTCCGGGTGTCGTGGGCTAATTCGTGACGCACGTTAGGTTACGGCTGTT',
         'TGCATTTTTTTCACATCAGGTCCATGCGGATTCTGAACGTGGCATAGGTGGGAAGTGGGCATCTGGGGGATGTGGCCTCGGCACCAGAGAGTGGGGTGGTTACGGCTGTT',
         'TGCATTTTTTTCACATCTTGGATTTATTAGGGGTTACGTGATTTTAGATATTCAACCATTTGAATAGCAGAGGAAAGAGTTTGCGTTTGGCAGGTAGGTTACGGCTGTT',
         'TGCATTTTTTTCACATCTTATCCGCAATGGCTACAGGGTAGGGGGTTCGTAACGATTTGACGGATTTTTGGGCACGCATCTGTAAGTAAGTGTACCGGTTACGGCTGTT',
         'TGCATTTTTTTCACATCAGCAAGCAGTTATAATTTGGATTTTGTATTTTTTCGGTGTTGAGGTTTTGTGGTGCAGTTATGTCATGGTGCTAATTAGGGTTACGGCTGTT',
         'TGCATTTTTTTCACATCCGAGATATTATGTAAAAACCCCGTCGCTTTGTTAATATTTGAAAGCTACTACTATCGTGCCTTTTTAGAATGGTGGCTTCGGTTACGGCTGTT',
         'TGCATTTTTTTCACATCTATCGATAGCGGTGTTAATACCGTTAATCTATGATTCAAAGGCACTAACGCGCTAAATGCTGAGATCGCGCATCCCGCGGGGTTACGGCTGTT',
   

In [31]:
torch.flip(dataset.data[:3]['tokenized']['hyena_input_ids'][0], [0])

TypeError: list indices must be integers or slices, not str

In [8]:
import torch
from datasets.hyena_tokenizer import CharacterTokenizer

hyena_tokenizer = CharacterTokenizer(characters={'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N': 4}, model_max_length=dataset.max_seq_length)
# print(hyena_tokenizer.decode()
# dataset.data[1]['sequences'][0][::-1]


In [11]:

# standard dataloader collate function
def collate_fn(batch):
    return {key: torch.stack([item[key] for item in batch]) for key in batch[0].keys()}

dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

In [5]:
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

In [6]:
for batch in dataloader:
    print(batch)
    break

{'condition': ['GSE104878_20161024_average_promoter_ELs_per_seq_3p1E7_Gly_ALL', 'GSE104878_20161024_average_promoter_ELs_per_seq_3p1E7_Gly_ALL', 'GSE104878_20161024_average_promoter_ELs_per_seq_3p1E7_Gal_ALL', 'GSE104878_20161024_average_promoter_ELs_per_seq_3p1E7_Gly_ALL', 'GSE104878_20161024_average_promoter_ELs_per_seq_3p1E7_Gal_ALL', 'GSE104878_20161024_average_promoter_ELs_per_seq_3p1E7_Gal_ALL', 'GSE104878_20161024_average_promoter_ELs_per_seq_3p1E7_Gal_ALL', 'GSE104878_20161024_average_promoter_ELs_per_seq_3p1E7_Gal_ALL'], 'center_quantile': tensor([ 7, 57, 65, 50, 22, 33, 40, 83]), 'samples': {'encoder_inputs': tensor([[[[0., 0., 0., 1., 0.],
          [0., 0., 1., 0., 0.],
          [0., 1., 0., 0., 0.],
          ...,
          [0., 0., 0., 0., 1.],
          [0., 0., 0., 0., 1.],
          [0., 0., 0., 0., 1.]],

         [[0., 0., 0., 1., 0.],
          [0., 0., 1., 0., 0.],
          [0., 1., 0., 0., 0.],
          ...,
          [0., 0., 0., 0., 1.],
          [0., 0., 0.

In [11]:
import torch
hyena_tokenizer.decode(batch['samples']['hyena_input_ids'][0][0]), batch['raw_texts'][0][0][::-1]

('[SEP]TTGTCGGCATTGGATGTGGCATCTGTACGCGGGGGGGGTCGACTGGCATACATTCAAGTTGTGTTGCTGGGACTTAAATTTGGCCTGGGACGTCTACACTTTTTTTACGT[CLS][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]',
 'TTGTCGGCATTGGATGTGGCATCTGTACGCGGGGGGGGTCGACTGGCATACATTCAAGTTGTGTTGCTGGGACTTAAATTTGGCCTGGGACGTCTACACTTTTTTTACGT')

: 

In [13]:
# Create encoder
encoder = hydra.utils.instantiate(cfg.encoder)

In [6]:
# Create generator (with model already instantiated)
generator = hydra.utils.instantiate(cfg.generator)

In [7]:
# Get model parameters
model_parameters = list(encoder.parameters()) + list(generator.model.parameters())

# Create optimizer and scheduler
optimizer = hydra.utils.instantiate(cfg.optimizer)(params=model_parameters)
scheduler = hydra.utils.instantiate(cfg.scheduler)(optimizer=optimizer)

# Create trainer
trainer = hydra.utils.instantiate(cfg.training)

In [8]:
output_dir, stats = trainer.train(
    encoder=encoder,
    generator=generator,
    dataloader=dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    output_dir=os.path.abspath('../outputs'),
    config=cfg,
)

sampling timestep 400

OutOfMemoryError: CUDA out of memory. Tried to allocate 36.34 GiB. GPU 0 has a total capacity of 79.10 GiB of which 17.42 GiB is free. Including non-PyTorch memory, this process has 61.67 GiB memory in use. Of the allocated memory 59.98 GiB is allocated by PyTorch, and 994.54 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)