In [1]:
import os
import argparse
import json
import numpy as np
import torch
import pytorch_lightning as pl 
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from inst_follow.models.clm import CLM
from mttl.callbacks import ProgressCallback
from mttl.datamodule.alpaca_data_module import AlpacaDataModule
from mttl.datamodule.longform_data_module import LongFormDataModule
from mttl.datamodule.wizzard_data_module import WizzardDataModule
from mttl.models.encoder_decoder import EncoderDecoder
from mttl.models.t0_encoder_decoder import T0EncoderDecoder
from mttl.config import Config as MTTLConfig
from mttl.models.monitors import get_monitors
from mttl.utils import get_mlf_logger
from mttl.models.modify_model import modify_transformer
from transformers import LlamaForCausalLM, LlamaTokenizer
from mttl.dataloader.data_utils import ExampleInfo
from mttl.utils import get_ni_tasks_from_file, trim_batch, hash_example
from typing import List

  from .autonotebook import tqdm as notebook_tqdm


Setting ds_accelerator to cuda (auto detect)


In [2]:

class Config(MTTLConfig):
    """MTTL config extended with LoRA / LLaMA-adapter fine-tuning options.

    The attributes assigned before ``super().__init__`` are defaults. Values
    supplied through config files / ``kwargs`` and applied by
    ``MTTLConfig.__init__`` presumably override them (TODO confirm against
    ``mttl.config.Config``).

    After the parent initialiser runs, ``train_batch_size`` is treated as the
    *effective* batch size. It is split into ``micro_batch_size`` examples per
    optimisation step times ``gradient_accumulation_steps``, and
    ``train_batch_size`` is finally overwritten with the per-step micro-batch
    size.

    Raises:
        ValueError: if the resolved ``train_batch_size`` is smaller than 1.
    """

    def __init__(self, **kwargs):
        self.rank = 1
        self.prune_unused_loras = True
        self.init_b_random = False
        self.lora_dropout = 0
        self.lora_alpha = 16
        self.load_in_8bit = False
        self.micro_batch_size = 4
        self.train_on_inputs = False
        self.padding_side = "right"
        self.adapter_modules = None
        self.poly_selector_use_distances = False
        self.adapter_layers = 0  # llama adapter
        self.adapter_len = 0  # llama adapter
        super().__init__(**kwargs)

        # to reproduce setup in https://github.com/daanelson/alpaca-lora:
        # effective batch = micro_batch_size * gradient_accumulation_steps.
        if self.train_batch_size < 1:
            raise ValueError(
                f"train_batch_size must be >= 1, got {self.train_batch_size}"
            )
        if self.train_batch_size < self.micro_batch_size:
            # Otherwise integer division yields 0 accumulation steps and the
            # per-step batch would silently exceed the requested batch size.
            self.micro_batch_size = self.train_batch_size
        if self.train_batch_size % self.micro_batch_size:
            import warnings

            warnings.warn(
                f"train_batch_size={self.train_batch_size} is not a multiple of "
                f"micro_batch_size={self.micro_batch_size}; the effective batch "
                f"size will be truncated to "
                f"{(self.train_batch_size // self.micro_batch_size) * self.micro_batch_size}."
            )
        self.gradient_accumulation_steps = (
            self.train_batch_size // self.micro_batch_size
        )
        # From here on, train_batch_size is the per-step (micro) batch size.
        self.train_batch_size = self.micro_batch_size


def parse_config(extra_kwargs=None, raise_error=True, parent=None, return_parser=False, c=None):
    """Build a ``Config`` from the command line or from an explicit config path.

    Two modes:

    * ``c is None`` (script use): parse ``sys.argv`` for ``-c/--config_files``
      and any number of ``-k/--kwargs key=value ...`` overrides.
    * ``c`` given (e.g. from a Jupyter notebook, whose ``sys.argv`` holds the
      kernel's own arguments): argparse is skipped, ``c`` is used as the
      config file(s), and there are no command-line overrides.

    Args:
        extra_kwargs: Optional dict merged on top of the ``-k`` overrides
            (its entries win on key collisions).
        raise_error: Forwarded unchanged to ``Config``.
        parent: Optional ``argparse.ArgumentParser`` whose arguments are
            inherited via ``parents=[parent]`` (CLI mode only).
        return_parser: If True, return ``(config, args)``. Note that ``args``
            is the parsed ``argparse.Namespace``, not the parser itself.
        c: Config file path(s) used instead of parsing the command line.

    Returns:
        The ``Config`` instance, or ``(config, args)`` when ``return_parser``.

    Raises:
        SystemExit: in CLI mode, via ``parser.error``, when a ``-k`` entry is
            not of the form ``key=value`` with a non-empty key.
    """
    import itertools

    parser = None
    # dont do it if called from jupyter notebook
    if c is None:
        parser = (
            argparse.ArgumentParser(parents=[parent])
            if parent
            else argparse.ArgumentParser()
        )
        parser.add_argument("-c", "--config_files", required=False)
        parser.add_argument("-k", "--kwargs", nargs="*", action="append")
        args = parser.parse_args()
    else:
        args = argparse.Namespace()
        args.kwargs = None
        args.config_files = c

    kwargs = {}
    # args.kwargs is only populated in CLI mode, so `parser` is set here.
    if args.kwargs:
        # `-k` may be repeated; action="append" yields a list of lists.
        for option in itertools.chain.from_iterable(args.kwargs):
            key, sep, value = option.partition("=")
            if not sep or not key:
                # Reject e.g. "-k lr" instead of silently mapping it to {"lr": ""}.
                parser.error(
                    f"invalid -k/--kwargs entry {option!r}: expected key=value"
                )
            kwargs[key] = value

    args.kwargs = kwargs
    if extra_kwargs:
        args.kwargs.update(extra_kwargs)

    config = Config(
        filenames=args.config_files, kwargs=args.kwargs, raise_error=raise_error
    )

    print(config.to_json())
    if return_parser:
        return config, args
    return config



Bad pipe message: %s [b'\x8dT\xf7/;\xea\xad\x1c\x0c"\xcb\xe4\x08#L\x13sf \xcew\x04\xd0\xdd\x1a\x7f\xef\xe8\xfd\xf2\xdf\x13)\xdb\xcc\xa7\x85*U\xcfi\x0cXU\xe0\xe3I\xb4\x99\x95\x8e\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08']
Bad pipe message: %s [b'\n\x08\x0b\x08\x04\x08\x05\x08']
Bad pipe message: %s [b'\x01\x05\x01\x06\x01']
Bad pipe message: %s [b'\x06\x88\x11`\xb8;Cv\xf5\x9c\x8c\x90H\xfc!\xae1\x17 \x0e\xe6', b"]\xca\x87\xdd'dq=q?\t\xa8\x82X\x9f\x14k\x16\xc4\xfe\xf4P\x18\x91\x8b*\x8d\xde\x00\x08"]
Bad pipe message: %s [b"t\x91\xefu*\x17\xfe\x85L-\xfe\xd3~\xf9@\x82\x07.\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0