In [1]:
# Add confit to sys.path
import sys
sys.path.append("/ConFit")

In [2]:
import gc
import os
import warnings
import time
import yaml

import torch
from torch.utils.data import Dataset, DataLoader
import accelerate
from accelerate import Accelerator
from accelerate.utils import set_seed
from transformers import EsmForMaskedLM, EsmTokenizer, EsmConfig
from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model
import pandas as pd
import numpy as np


from confit.data_utils import Mutation_Set, split_train, sample_data

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Run parameters: config location, target dataset and the two seeds.
config_file = "/ConFit/config/training_config.yaml"
dataset = "GB1_Olson2014_ddg"
# model_seed selects the ESM-1v checkpoint suffix and the validation fold;
# sample_seed is forwarded to sample_data().
model_seed, sample_seed = 1, 0



In [4]:
# Read the training configuration (batch size, LoRA and optimizer hyperparameters, ...).
# safe_load builds only plain YAML types (dicts, lists, strings, numbers), which is all
# this config needs, and never instantiates arbitrary Python objects.
with open(config_file, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)
if not isinstance(config, dict):
    # An empty or non-mapping file would otherwise fail much later on config['...'].
    raise ValueError(f"{config_file} must contain a YAML mapping, got {type(config).__name__}")

In [5]:
# Per-process batch size: config['batch_size'] is divided across config['gpu_number']
# processes (integer division; any remainder is dropped).
batch_size = int(config['batch_size']) // int(config['gpu_number'])
if batch_size < 1:
    # Fail here with a clear message instead of inside the spawned DataLoader workers.
    raise ValueError(
        f"batch_size={config['batch_size']} is smaller than gpu_number={config['gpu_number']}; "
        "each process would get an empty batch"
    )

### Model: ESM-1v, chosen because it is better suited to zero-shot prediction of mutation effects

# Train

In [6]:
from confit.train import train, evaluate

In [7]:
# LoRA adapters on the submodules named "query" and "value"; rank, scaling and
# dropout come from the YAML config.
# NOTE(review): task_type is "CAUSAL_LM" although the wrapped model is
# EsmForMaskedLM — confirm this is intentional.
peft_config = LoraConfig(
    target_modules=["query", "value"],
    r=int(config['lora_r']),
    lora_alpha=int(config['lora_alpha']),
    lora_dropout=float(config['lora_dropout']),
    task_type="CAUSAL_LM",
)

In [8]:


def _load_splits(accelerator):
    """Create the data splits on the main process, then read them on every process.

    Returns:
        (train_csv, val_csv): ``train_csv`` concatenates folds ``train_1.csv`` ..
        ``train_5.csv``; ``val_csv`` is fold number ``model_seed``.
    """
    if accelerator.is_main_process:
        sample_data(dataset, sample_seed, int(config['shot']))
        split_train(dataset)

    # Non-main processes wait here until the main process has finished the block,
    # i.e. until the split files above exist.
    with accelerator.main_process_first():
        folds = [pd.read_csv(f'../data/{dataset}/train_{i}.csv') for i in range(1, 6)]

    # NOTE(review): as in the original loop, the validation fold is also part of the
    # training data, so early stopping is scored on data the model trains on — confirm.
    train_csv = pd.concat(folds, axis=0)
    val_csv = folds[model_seed - 1]
    return train_csv, val_csv


def _save_checkpoint(accelerator, model, save_path):
    """Write the (unwrapped) model's adapter to ``save_path`` from the main process only."""
    accelerator.wait_for_everyone()
    # Only one rank writes; previously every rank wrote the same files concurrently.
    # NOTE(review): under FSDP the state-dict gather may need all ranks to participate.
    if accelerator.is_main_process:
        os.makedirs(save_path, exist_ok=True)
        accelerator.unwrap_model(model).save_pretrained(save_path)


def training_loop(mixed_precision="fp16", seed: int = 42, batch_size: int = 64):
    """Fine-tune ESM-1v with LoRA on ``dataset`` and checkpoint the best epoch.

    Meant to run in every worker process started by ``accelerate.notebook_launcher``.
    Besides its arguments it reads the notebook globals ``config``, ``peft_config``,
    ``dataset``, ``model_seed`` and ``sample_seed``.

    Args:
        mixed_precision: Forwarded to ``Accelerator`` (e.g. ``"no"``, ``"fp16"``).
        seed: RNG seed applied via ``accelerate.utils.set_seed``.
        batch_size: Per-process batch size of the training DataLoader.

    Raises:
        ValueError: If ``config['model']`` is not ``'ESM-1v'`` or ``model_seed`` is
            outside 1..5.
    """
    # model_seed picks both the ESM-1v checkpoint suffix and the validation fold (1..5).
    if not 1 <= model_seed <= 5:
        raise ValueError(f"model_seed must be in 1..5, got {model_seed}")
    if config['model'] != 'ESM-1v':
        # Any other value used to fall through and crash later with a NameError.
        raise ValueError(f"Unsupported model {config['model']!r}; only 'ESM-1v' is implemented")

    set_seed(seed)
    accelerator = Accelerator(mixed_precision=mixed_precision)

    # The trainable model gets LoRA adapters; model_reg is a frozen copy of the same
    # checkpoint that train() receives as the regularization reference.
    checkpoint = f'facebook/esm1v_t33_650M_UR90S_{model_seed}'
    basemodel = EsmForMaskedLM.from_pretrained(checkpoint)
    model_reg = EsmForMaskedLM.from_pretrained(checkpoint)
    tokenizer = EsmTokenizer.from_pretrained(checkpoint)
    for param in model_reg.parameters():
        param.requires_grad = False
    model_reg.eval()
    model = get_peft_model(basemodel, peft_config)

    optimizer = torch.optim.Adam(model.parameters(), lr=float(config['ini_lr']))
    # scheduler.step() runs once per epoch and T_0 = 2 * max_epochs, so the LR follows only
    # the first half of one cosine cycle (no warm restart happens during training).
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=2 * int(config['max_epochs']), eta_min=float(config['min_lr'])
    )
    if os.environ.get("ACCELERATE_USE_FSDP", None) is not None:
        # Imported lazily: only needed under FSDP and absent from the notebook's imports.
        from peft.utils.other import fsdp_auto_wrap_policy
        accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)

    model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)
    model_reg = accelerator.prepare(model_reg)

    train_csv, val_csv = _load_splits(accelerator)
    trainset = Mutation_Set(data=train_csv, fname=dataset, tokenizer=tokenizer)
    valset = Mutation_Set(data=val_csv, fname=dataset, tokenizer=tokenizer)

    # prepare() shards the training loader across processes; without it every rank
    # iterated the identical full training set in the same order.
    trainloader = accelerator.prepare(
        DataLoader(trainset, batch_size=batch_size, collate_fn=trainset.collate_fn, shuffle=True)
    )
    # Left unprepared (as before), so evaluate() sees the full validation fold on every rank.
    valloader = DataLoader(valset, batch_size=2, collate_fn=valset.collate_fn)

    save_path = os.path.join('checkpoint', dataset, f'seed{model_seed}')
    best_sr = -np.inf
    endure = 0
    best_epoch = 0
    for epoch in range(int(config['max_epochs'])):
        loss = train(model, model_reg, trainloader, optimizer, tokenizer, float(config['lambda_reg']))
        accelerator.print(f'========epoch{epoch}; training loss :{loss}=================')
        sr = evaluate(model, valloader, tokenizer, accelerator)
        accelerator.print(f'========epoch{epoch}; val spearman correlation :{sr}=================')
        scheduler.step()

        if best_sr > sr:
            endure += 1
        else:
            # Ties count as improvements: counter resets and the checkpoint is refreshed.
            endure = 0
            best_sr = sr
            best_epoch = epoch
            _save_checkpoint(accelerator, model, save_path)

        if sr == 1.0 or endure > int(config['endure_time']):
            accelerator.print(f'========early stop at epoch{epoch}!============')
            break

    accelerator.print(f'========best epoch{best_epoch}; val spearman correlation :{best_sr}=================')

In [None]:
# def training_loop(mixed_precision="fp16", seed: int = 42, batch_size: int = 64):
#     set_seed(seed)
#     # Initialize accelerator
#     accelerator = Accelerator(mixed_precision=mixed_precision)
#     # Build dataloaders
#     train_dataloader, eval_dataloader = get_dataloaders(batch_size)

#     # Instantiate the model (you build the model here so that the seed also controls new weight initaliziations)
#     model = create_model("resnet50d", pretrained=True, num_classes=len(label_to_id))

#     # Freeze the base model
#     for param in model.parameters():
#         param.requires_grad = False
#     for param in model.get_classifier().parameters():
#         param.requires_grad = True

#     # You can normalize the batches of images to be a bit faster
#     mean = torch.tensor(model.default_cfg["mean"])[None, :, None, None]
#     std = torch.tensor(model.default_cfg["std"])[None, :, None, None]

#     # To make these constants available on the active device, set it to the accelerator device
#     mean = mean.to(accelerator.device)
#     std = std.to(accelerator.device)

#     # Instantiate the optimizer
#     optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-2 / 25)

#     # Instantiate the learning rate scheduler
#     lr_scheduler = OneCycleLR(optimizer=optimizer, max_lr=3e-2, epochs=5, steps_per_epoch=len(train_dataloader))

#     # Prepare everything
#     # There is no specific order to remember, you just need to unpack the objects in the same order you gave them to the
#     # prepare method.
#     model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
#         model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
#     )

#     # Now you train the model
#     for epoch in range(5):
#         model.train()
#         for batch in train_dataloader:
#             inputs = (batch["image"] - mean) / std
#             outputs = model(inputs)
#             loss = torch.nn.functional.cross_entropy(outputs, batch["label"])
#             accelerator.backward(loss)
#             optimizer.step()
#             lr_scheduler.step()
#             optimizer.zero_grad()

#         model.eval()
#         accurate = 0
#         num_elems = 0
#         for batch in eval_dataloader:
#             inputs = (batch["image"] - mean) / std
#             with torch.no_grad():
#                 outputs = model(inputs)
#             predictions = outputs.argmax(dim=-1)
#             accurate_preds = accelerator.gather(predictions) == accelerator.gather(batch["label"])
#             num_elems += accurate_preds.shape[0]
#             accurate += accurate_preds.long().sum()

#         eval_metric = accurate.item() / num_elems
#         # Use accelerator.print to print only on the main process.
#         accelerator.print(f"epoch {epoch}: {100 * eval_metric:.2f}")

In [9]:
from accelerate import notebook_launcher

In [10]:
# Start one training process per GPU declared in the config, each with the per-process
# batch size computed from config['batch_size'] / config['gpu_number'].
# NOTE(review): "Error while creating shared memory segment /dev/shm/nccl-..." is an
# environment limit (a small /dev/shm is common in Docker). Enlarge it (e.g. docker
# `--shm-size`) or try NCCL_SHM_DISABLE=1 before launching; this cell cannot fix it.
args = ("fp16", 42, batch_size)
notebook_launcher(training_loop, args, num_processes=int(config['gpu_number']))

Launching training on 2 GPUs.


Some weights of EsmForMaskedLM were not initialized from the model checkpoint at facebook/esm1v_t33_650M_UR90S_1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of EsmForMaskedLM were not initialized from the model checkpoint at facebook/esm1v_t33_650M_UR90S_1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of EsmForMaskedLM were not initialized from the model checkpoint at facebook/esm1v_t33_650M_UR90S_1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of EsmFor

RuntimeError: An issue was found when launching the training: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 68, in _wrap
    fn(i, *args)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/utils/launch.py", line 570, in __call__
    self.launcher(*args)
  File "/tmp/ipykernel_65972/2310528820.py", line 24, in training_loop
    model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/accelerator.py", line 1228, in prepare
    result = tuple(
  File "/usr/local/lib/python3.8/dist-packages/accelerate/accelerator.py", line 1229, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/accelerator.py", line 1105, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/accelerator.py", line 1356, in prepare_model
    model = torch.nn.parallel.DistributedDataParallel(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/distributed.py", line 798, in __init__
    _verify_param_shape_across_processes(self.process_group, parameters)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/utils.py", line 263, in _verify_param_shape_across_processes
    return dist._verify_params_across_processes(process_group, tensors, logger)
torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1691, unhandled system error (run with NCCL_DEBUG=INFO for details), NCCL version 2.19.3
ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. 
Last error:
Error while creating shared memory segment /dev/shm/nccl-iEhy0K (size 9637888)
