In [1]:
import numpy as np
import torch as torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM,AutoTokenizer,BitsAndBytesConfig,AutoModel,TrainingArguments, Trainer,DataCollatorWithPadding, EarlyStoppingCallback, get_constant_schedule_with_warmup
import transformers
from datasets import load_dataset, Dataset, load_from_disk
from stripedhyena.tokenizer import CharLevelTokenizer
from stripedhyena.model import StripedHyena
from stripedhyena.layers import VocabParallelEmbedding
from stripedhyena.utils import dotdict
import yaml
import multiprocessing
from tqdm import tqdm
from accelerate import Accelerator, DistributedType
from torch.utils.data import DataLoader
import os

In [2]:
data_files = './processed_data'
cpu_cnt = multiprocessing.cpu_count()
max_length = 300
model_name = './evo-1-8k-base'


In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_cached(0)/1024**3,1), 'GB')


print(torch.version.cuda)

Using device: cuda

NVIDIA A100-SXM4-40GB
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB
11.8




In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True,padding='max_length',bos_token='\x00')
tokenizer.pad_token = '\x01'
data_collator = DataCollatorWithPadding(tokenizer=tokenizer,max_length=max_length,padding='max_length')
#Evo utilizes a custom tokenizer named "CharLevelTokenizer", huggingface version does not have pad_token defined. Thus pad_token is manually defined to have same value as the pad_token utilizes in CharLevelTokenizer

def prefix_function(d):
    d['dna_seq'] = tokenizer.bos_token + d['dna_seq']
    return d

def tokenize_function(d):
    return tokenizer(d['dna_seq'], padding='max_length',max_length = max_length)

In [5]:
def prepare_dsets(data_path,batch_size):
    dsets = load_from_disk(data_path)
    dsets = dsets.map(prefix_function, num_proc=cpu_cnt)
    dsets = dsets.map(tokenize_function, batched=True, num_proc=cpu_cnt)
    useful_col = ['input_ids', 'attention_mask', 'labels']
    dsets = dsets.remove_columns([col for col in dsets['train'].column_names if col not in useful_col])
    train_loader = DataLoader(
        dsets['train'],
        collate_fn=data_collator,
        shuffle=True,
        batch_size=batch_size
        )
    val_loader = DataLoader(
            dsets['val'],
            collate_fn=data_collator,
            shuffle=True,
            batch_size=batch_size
        )
    test_loader = DataLoader(
            dsets['test'],
            collate_fn=data_collator,
            shuffle=True,
            batch_size=batch_size
        )
    return train_loader,val_loader,test_loader, len(dsets['train'])


In [14]:
def train_model(model,OUT_DIR,data_path,batch_size = 128,warmup_steps = 10,epochs = 3,learning_rate=2e-5,checkpointing_steps=1000):

    train_loader,val_loader,test_loader, epoch_size = prepare_dsets(data_path,batch_size)
    os.makedirs(OUT_DIR, exist_ok=True)
    
    
    accelerator = Accelerator(
        gradient_accumulation_steps=1,
        mixed_precision="bf16"
    )
    
    accelerator.print(f"Total GPUS: {accelerator.num_processes}")
    
    
    num_training_steps = epoch_size * epochs // batch_size
    num_epoch_steps = num_training_steps // epochs

    optim = torch.optim.AdamW(list(model.parameters()), lr=learning_rate)
    scheduler = transformers.get_cosine_schedule_with_warmup(optim, num_warmup_steps=warmup_steps,num_training_steps = num_training_steps)
                                                  
    model,optim, train_loader, scheduler= accelerator.prepare(model,optim, train_loader, scheduler)
    
    accelerator.register_for_checkpointing(scheduler)
    
    progress_bar = tqdm(range(num_training_steps))
    completed_steps = 0
    

    loss_fn = torch.nn.MSELoss()

    

    stp_lst = []
    for i in os.listdir(OUT_DIR):
        if i[:4]=='step':
            stp_lst.append(int(i[5:]))

    if stp_lst != []:
        resume_from_checkpoint = True
        resume_step = max(stp_lst)
        latest = 'step_' + str(resume_step)
        checkpoint_dir = os.path.join(OUT_DIR,latest)
        path = os.path.basename(checkpoint_dir)
        accelerator.print(
                f"Resuming from checkpoint {latest}")
        accelerator.load_state(checkpoint_dir, strict=False)
    else: 
        resume_from_checkpoint = False
    
    if resume_from_checkpoint and resume_step is not None:
        train_loader_skipped = accelerator.skip_first_batches(
            train_loader, resume_step % num_epoch_steps)
        completed_steps += resume_step
        progress_bar.update(resume_step)
        accelerator.print(f"Resuming training from step {resume_step}")
        epoch = resume_step // num_epoch_steps
        skip = True
    else:
        epoch = 0
        skip = False

    
    log_loss = os.path.join(OUT_DIR,'train_log')
    #loss_file = open(log_loss, "w", encoding="utf-8") if accelerator.is_main_process else None
    loss_file = open(log_loss, "a" if resume_from_checkpoint else "w", encoding="utf-8") if accelerator.is_main_process else None
    
    print(f'Starting epoch: {epoch}, Ending Epoch: {epochs}, Total training steps: {num_training_steps}')
    
    for epoch in range(epoch,epochs):
        model.train()
        if skip:
            for step, batch in enumerate(train_loader_skipped):
                loss_log = None
                skip = False
                with accelerator.accumulate(model):
                    outputs = model(**batch)
                    loss = loss_fn(outputs.squeeze(),batch['labels'].squeeze())
                    accelerator.backward(loss)
                    
                    if accelerator.sync_gradients:
                        last_lr = scheduler.get_last_lr()
                        if isinstance(last_lr, list):
                            last_lr = last_lr[0]
                        loss_log = {
                            "loss": loss.item(),
                            "epoch": completed_steps / num_epoch_steps,
                            "learning_rate": last_lr
                        }
                        accelerator.log(loss_log, step=completed_steps)
                        if loss_file is not None:
                            loss_file.write(f"{loss_log['loss']},")
                            loss_file.flush()
        
                        
                    optim.step()
                    scheduler.step()
                    optim.zero_grad()
                
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    if loss_log is not None:
                        progress_bar.set_postfix(loss_log)
                    completed_steps += 1
                    
        
                    if completed_steps > 0:
                        if completed_steps % checkpointing_steps == 0:
                            output_dir = f"step_{completed_steps}"
                            if OUT_DIR is not None:
                                output_dir = os.path.join(
                                    OUT_DIR, output_dir)
                            accelerator.save_state(output_dir)
    
            
                if completed_steps >= num_training_steps:
                    print(completed_steps)
                    print(num_training_steps)
                    print('broke')
                    break
        else: 
            for step, batch in enumerate(train_loader):
                loss_log = None
                with accelerator.accumulate(model):
                    outputs = model(**batch)
                    loss = loss_fn(outputs.squeeze(),batch['labels'].squeeze())
                    accelerator.backward(loss)
                    
                    if accelerator.sync_gradients:
                        last_lr = scheduler.get_last_lr()
                        if isinstance(last_lr, list):
                            last_lr = last_lr[0]
                        loss_log = {
                            "loss": loss.item(),
                            "epoch": completed_steps / num_epoch_steps,
                            "learning_rate": last_lr
                        }
                        accelerator.log(loss_log, step=completed_steps)
                        if loss_file is not None:
                            loss_file.write(f"{loss_log['loss']},")
                            loss_file.flush()
        
                        
                    optim.step()
                    scheduler.step()
                    optim.zero_grad()
                
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    if loss_log is not None:
                        progress_bar.set_postfix(loss_log)
                    completed_steps += 1
                    
        
                    if completed_steps > 0:
                        if completed_steps % checkpointing_steps == 0:
                            output_dir = f"step_{completed_steps}"
                            if OUT_DIR is not None:
                                output_dir = os.path.join(
                                    OUT_DIR, output_dir)
                            accelerator.save_state(output_dir)
    
            
                if completed_steps >= num_training_steps:
                    print(completed_steps)
                    print(num_training_steps)
                    print('broke')
                    break
                    
        
            
        if completed_steps >= num_training_steps:
            print(completed_steps)
            print(num_training_steps)
            print('broke')
            break
    
    accelerator.print("Training Finished")
    accelerator.end_training()

    if OUT_DIR is not None:
            accelerator.print(f"Saving model to {OUT_DIR}")
    
            accelerator.wait_for_everyone()
    
            if accelerator.distributed_type == DistributedType.FSDP:
                full_state_dict_config = FullStateDictConfig(
                    offload_to_cpu=True, rank0_only=True)
                with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
                    state_dict = accelerator.get_state_dict(model, unwrap=False)
            else:
                state_dict = accelerator.get_state_dict(model)
    
            torch.save(state_dict,os.path.join(OUT_DIR,"final.pickle"))
            accelerator.print("Saving Finished")
    model, optim, train_loader, scheduler = accelerator.free_memory(model,optim, train_loader, scheduler)

## Models

In [7]:
model_config = AutoConfig.from_pretrained(model_name,trust_remote_code=True)
model_config.use_cache = False
model_og = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=model_config,
    trust_remote_code=True,
)

config_path = './custom_config.yaml'
state_dict = model_og.backbone.state_dict()
config = yaml.safe_load(open(config_path, 'rb').read())
global_config = dotdict(config, Loader=yaml.FullLoader)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [8]:
'''
class MyEmbed(VocabParallelEmbedding):
    def __init__(self, config):
        vocab_size, process_group, padding_idx = (
            config.vocab_size,
            config.get("process_group", None),
            config.get("padding_idx", None),
        )
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if vocab_size % world_size != 0:
                raise ValueError(f"vocab_size ({vocab_size}) must be divisible by " f"world_size ({world_size})")
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(
            vocab_size // world_size,
            embedding_dim=config.hidden_size,
            padding_idx=padding_idx,
        )
    def embed(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return self.forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = (
                rank * vocab_size,
                (rank + 1) * vocab_size,
            )
            # Create a mask of valid vocab ids (1 means it needs to be masked).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = self.forward(input)
            embeddings[input_ids_mask] = 0.0
            # Reduce to the global process group
            torch.distributed.all_reduce(embeddings, group=self.process_group)
            return embeddings

'''

'\nclass MyEmbed(VocabParallelEmbedding):\n    def __init__(self, config):\n        vocab_size, process_group, padding_idx = (\n            config.vocab_size,\n            config.get("process_group", None),\n            config.get("padding_idx", None),\n        )\n        self.process_group = process_group\n        if process_group is not None:\n            world_size = torch.distributed.get_world_size(process_group)\n            if vocab_size % world_size != 0:\n                raise ValueError(f"vocab_size ({vocab_size}) must be divisible by " f"world_size ({world_size})")\n            if world_size > 1 and padding_idx is not None:\n                raise RuntimeError("ParallelEmbedding does not support padding_idx")\n        else:\n            world_size = 1\n        super().__init__(\n            vocab_size // world_size,\n            embedding_dim=config.hidden_size,\n            padding_idx=padding_idx,\n        )\n    def embed(self, input: Tensor) -> Tensor:\n        if self.pro

In [9]:
class MyHyena(StripedHyena):
    def __init_(self,config):
        super().__init__()
        self.config = config
        self.embedding_layer = VocabParallelEmbedding(config)
        self.norm = RMSNorm(config) if config.get("final_norm", True) else None
        #self.unembed = self.embedding_layer if config.tie_embeddings else VocabParallelEmbedding(config)

        if config.get("use_flashfft", "True"):
            try:
                from flashfftconv import FlashFFTConv

                self.flash_fft = FlashFFTConv(config.seqlen, dtype=torch.bfloat16)
            except ImportError:
                "flashfftconv not installed"
        else:
            self.flash_fft = None

        self.blocks = nn.ModuleList(
            get_block(config, layer_idx, flash_fft=self.flash_fft) for layer_idx in range(config.num_layers)
        )

    def forward(self, x, inference_params_dict=None, padding_mask=None):
        L = x.shape[1]
        x = self.embedding_layer.embed(x)
        if inference_params_dict is not None:
            x, inference_params_dict_out = self.stateful_forward(
                x,
                inference_params_dict=inference_params_dict,
            )
        else:
            x, inference_params_dict_out = self.stateless_forward(x, padding_mask=padding_mask)

        x = self.norm(x)
        #Removed unembedding
        #x = self.unembed.unembed(x)
        return x, inference_params_dict_out

In [10]:
custom_hyena = MyHyena(global_config)
custom_hyena.load_state_dict(state_dict,strict=True)


for p in custom_hyena.parameters():
    p.requires_grad = False

In [11]:
class EvoForRegression(nn.Module):
  def __init__(self): 
    super(EvoForRegression,self).__init__() 

    self.model = custom_hyena
    self.dropout = nn.Dropout(0.1) 
    #First_Token representation

    self.lin = nn.Linear(model_config.hidden_size,model_config.hidden_size) 
    self.out = nn.Linear(model_config.hidden_size,1) 
      

  def forward(self, input_ids=None, attention_mask=None,labels=None):
    #Extract outputs from the hyena
    outputs = self.model(x=input_ids, padding_mask=attention_mask)[0][:,0,:]
    #Add custom layers
    outputs = self.dropout(outputs)

    outputs = self.lin(outputs)
      
    outputs = nn.ReLU()(outputs)  # (bs, dim)
    outputs = self.dropout(outputs)  # (bs, dim)
    out = self.out(outputs)

    return out


In [12]:
model = EvoForRegression()

In [15]:
# Examine batchsize effect on model
batchsizes = [64,128]
OUT_DIR='./loop/'
data_path = data_files

val_loss_lst = []
for i in batchsizes:
    outdir = OUT_DIR + str(i)
    del model
    model = EvoForRegression()
    train_model(model,outdir,data_path,batch_size = i,warmup_steps = 10,epochs = 1,learning_rate=2e-5,checkpointing_steps=1000)

Total GPUS: 1



  0%|          | 0/8441 [00:00<?, ?it/s][A

Resuming from checkpoint step_8000



 95%|█████████▍| 8000/8441 [01:46<00:05, 75.19it/s][A

Resuming training from step 8000
Starting epoch: 0, Ending Epoch: 1, Total training steps: 8441



 95%|█████████▍| 8001/8441 [01:49<00:05, 75.19it/s, loss=2.83, epoch=0.948, learning_rate=1.35e-7][A
 95%|█████████▍| 8002/8441 [01:51<00:05, 75.19it/s, loss=2.44, epoch=0.948, learning_rate=1.34e-7][A
 95%|█████████▍| 8003/8441 [01:54<00:05, 75.19it/s, loss=2.94, epoch=0.948, learning_rate=1.33e-7][A
 95%|█████████▍| 8004/8441 [01:56<00:05, 75.19it/s, loss=2.48, epoch=0.948, learning_rate=1.33e-7][A
 95%|█████████▍| 8005/8441 [01:58<00:05, 75.19it/s, loss=2.76, epoch=0.948, learning_rate=1.32e-7][A
 95%|█████████▍| 8006/8441 [02:00<00:05, 75.19it/s, loss=3.18, epoch=0.948, learning_rate=1.32e-7][A
 95%|█████████▍| 8007/8441 [02:03<00:05, 75.19it/s, loss=2.91, epoch=0.948, learning_rate=1.31e-7][A
 95%|█████████▍| 8007/8441 [02:05<00:05, 75.19it/s, loss=2.91, epoch=0.948, learning_rate=1.31e-7][A
 95%|█████████▍| 8008/8441 [02:05<00:07, 60.09it/s, loss=2.91, epoch=0.948, learning_rate=1.31e-7][A
 95%|█████████▍| 8008/8441 [02:05<00:07, 60.09it/s, loss=2.76, epoch=0.949, learn

8441
8441
broke
8441
8441
broke
Training Finished
Saving model to ./loop/64


100%|██████████| 8441/8441 [16:56<00:00,  8.31it/s, loss=2.48, epoch=1, learning_rate=6.94e-13]

Saving Finished





Total GPUS: 1



  0%|          | 0/4220 [00:00<?, ?it/s][A

Starting epoch: 0, Ending Epoch: 1, Total training steps: 4220



  0%|          | 1/4220 [00:04<4:44:44,  4.05s/it][A
  0%|          | 1/4220 [00:04<4:44:44,  4.05s/it, loss=6.57, epoch=0, learning_rate=0][A
  0%|          | 2/4220 [00:07<4:34:08,  3.90s/it, loss=6.57, epoch=0, learning_rate=0][A
  0%|          | 2/4220 [00:07<4:34:08,  3.90s/it, loss=6.43, epoch=0.000237, learning_rate=2e-6][A
  0%|          | 3/4220 [00:11<4:31:03,  3.86s/it, loss=6.43, epoch=0.000237, learning_rate=2e-6][A
  0%|          | 3/4220 [00:11<4:31:03,  3.86s/it, loss=5.55, epoch=0.000474, learning_rate=4e-6][A
  0%|          | 4/4220 [00:15<4:29:23,  3.83s/it, loss=5.55, epoch=0.000474, learning_rate=4e-6][A
  0%|          | 4/4220 [00:15<4:29:23,  3.83s/it, loss=6.98, epoch=0.000711, learning_rate=6e-6][A
  0%|          | 5/4220 [00:19<4:28:21,  3.82s/it, loss=6.98, epoch=0.000711, learning_rate=6e-6][A
  0%|          | 5/4220 [00:19<4:28:21,  3.82s/it, loss=5.5, epoch=0.000948, learning_rate=8e-6] [A
  0%|          | 6/4220 [00:23<4:27:39,  3.81s/it, loss=

4220
4220
broke
4220
4220
broke
Training Finished
Saving model to ./loop/128


100%|██████████| 4220/4220 [4:31:13<00:00,  3.86s/it, loss=2.73, epoch=1, learning_rate=2.78e-12]

Saving Finished



