In [1]:
import os
import math
import time

from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig

import bittensor
import torch 
from torch.utils.data import DataLoader
import wandb
import datasets
from datasets import load_from_disk, concatenate_datasets
from accelerate import Accelerator
import transformers
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    default_data_collator,
    get_scheduler,
)

has no attribute 'buffer'


In [2]:
def check_cfg_and_load_defaults(cfg: DictConfig) -> DictConfig:

    subtensor = bittensor.subtensor(network=cfg.bittensor.network)
    if cfg.dataset.block_size is None:
        cfg.dataset.block_size = subtensor.validator_sequence_length
    if cfg.training.train_batch_size is None:
        cfg.training.train_batch_size = subtensor.validator_batch_size
    if cfg.training.eval_batch_size is None:
        cfg.training.eval_batch_size = subtensor.validator_batch_size

    return cfg

In [49]:
def load_dataset(cfg: DictConfig):
    status = True
    i = 0

    while status:
        file_name = cfg.dataset.file_name + "_" + str(i)
        data_file = os.path.join(cfg.dataset.data_dir, file_name)

        try:
            tokenized_dataset_batch = load_from_disk(data_file)
            tokenized_dataset_batch.set_format(type='pt')
            print(f"loaded data from {data_file}.")
        except:
            status = False
            print(f"{data_file} doesn't exist.")

        if i==0:
            tokenized_dataset = tokenized_dataset_batch
        else:
            tokenized_dataset = concatenate_datasets([tokenized_dataset, tokenized_dataset_batch])
        
        i += 1

    return tokenized_dataset

In [4]:
def load_tokenizer(cfg: DictConfig):
    if cfg.tokenizer.name is not None:
        tokenizer = AutoTokenizer.from_pretrained(
            cfg.tokenizer.name, use_fast=cfg.tokenizer.use_fast
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            cfg.model.name, use_fast=cfg.tokenizer.use_fast
        )
    
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    return tokenizer

In [5]:
# Accelerate config file at: /home/paperspace/.cache/huggingface/accelerate/default_config.yaml

# compute_environment: LOCAL_MACHINE
# distributed_type: MULTI_GPU
# downcast_bf16: 'no'
# gpu_ids: all
# machine_rank: 0
# main_process_ip: ''
# main_process_port: 29500
# main_training_function: main
# mixed_precision: 'no'
# num_machines: 4
# num_processes: 4
# rdzv_backend: static
# same_network: true
# tpu_env: []
# tpu_use_cluster: false
# tpu_use_sudo: false
# use_cpu: false

In [6]:
# import accelerate

# accelerate.load_state(input_dir="/home/paperspace/.cache/huggingface/accelerate/default_config.yaml")

In [7]:
def create_accelerator(cfg: DictConfig) -> Accelerator:

    accelerator = (
        Accelerator(log_with=cfg.tracking.report_to, project_dir=cfg.output_dir)
        if cfg.tracking.enabled
        else Accelerator()
    )
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    return accelerator

In [8]:
def load_model(cfg: DictConfig, tokenizer):

    if cfg.model.config_name is not None:
        config = AutoConfig.from_pretrained(cfg.model.config_name)
    else:
        config = AutoConfig.from_pretrained(cfg.model.name)

    model = AutoModelForCausalLM.from_pretrained(
        cfg.model.name,
        from_tf=bool(".ckpt" in cfg.model.name),
        config=config,
    )
    model.resize_token_embeddings(len(tokenizer))

    return model

In [9]:
def create_optimizer(cfg, model):

    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": cfg.training.weight_decay,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    return torch.optim.AdamW(
        optimizer_grouped_parameters, lr=cfg.training.learning_rate
    )

In [10]:
def set_seed(seed=17):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

In [50]:
# print('a')
# accelerator = create_accelerator(cfg)
# print('b')
# accelerator.wait_for_everyone()
# print('c')
cfg = OmegaConf.load('conf/config.yaml')
cfg = check_cfg_and_load_defaults(cfg)

tokenized_datasets = load_dataset(cfg)
tokenizer = load_tokenizer(cfg)

/mnt/share/ipfs-data/data/tokenized_data_0 doesn't exist.


In [14]:
if "train" not in tokenized_datasets.column_names:
    tokenized_datasets = tokenized_datasets.train_test_split(
        test_size=cfg.training.val_split_percent / 100
    )
    tokenized_datasets_test_valid = tokenized_datasets["test"].train_test_split(
        test_size=0.5
    )
    tokenized_datasets["test"] = tokenized_datasets_test_valid["train"]
    tokenized_datasets["validation"] = tokenized_datasets_test_valid["test"]
    
print(f"train shape {tokenized_datasets['train'].shape}")
print(f"eval shape {tokenized_datasets['validation'].shape}")
print(f"test shape {tokenized_datasets['test'].shape}")

train shape (6758, 4)
eval shape (845, 4)
test shape (845, 4)


In [15]:
train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["validation"]
test_dataset = tokenized_datasets["test"]

In [16]:
train_dataset

Dataset({
    features: ['text', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 6758
})

In [20]:
train_dataset.column_names

['text', 'input_ids', 'attention_mask', 'labels']

In [34]:
def describe_tensor(dataset):
    for column in dataset.column_names:
        print(f"type of {column} is {type(dataset[column])}")
        # print(dataset[column][0])

        # dataset[column].to('cuda:0')

In [47]:
def move_to_tensor(dataset):
    for column in dataset.column_names:
        if column != 'text':
            dataset[column] = torch.tensor(dataset[column]).to('cuda:0')
        # print(f"type of {column} is {type(dataset[column])}")
        # print(dataset[column][0])

In [37]:
describe_tensor(train_dataset)

type of text is <class 'list'>
type of input_ids is <class 'list'>
type of attention_mask is <class 'list'>
type of labels is <class 'list'>


In [48]:
move_to_tensor(train_dataset)

In [42]:
describe_tensor(train_dataset)

type of text is <class 'list'>
type of input_ids is <class 'list'>
type of attention_mask is <class 'list'>
type of labels is <class 'list'>


In [14]:
# train_dataloader = DataLoader(
#     train_dataset,
#     shuffle=True,
#     collate_fn=default_data_collator,
#     batch_size=cfg.training.train_batch_size,
# )

# eval_dataloader = DataLoader(
#     eval_dataset,
#     collate_fn=default_data_collator,
#     batch_size=cfg.training.eval_batch_size,
# )

# test_dataloader = DataLoader(
#     test_dataset,
#     collate_fn=default_data_collator,
#     batch_size=cfg.training.eval_batch_size,
# )

In [15]:
from torch import Tensor
import numpy as np 

def attention_zeros(attn) -> Tensor:
    attn_zeros = attn
    idxs = np.random.randint(0, len(attn[0]-1), len(attn))
    
    for i, k in enumerate(zip(idxs, attn)):
        attn_zeros[i][k[1]][k[0]] = 0

    return attn_zeros

In [16]:
# (
#   model,
#   optimizer,
#     train_dataloader,
#     eval_dataloader,
#     lr_scheduler,
# ) = accelerator.prepare(
#         model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
#     )

In [17]:
# # Scheduler and math around the number of training steps.
# overrode_max_train_steps = False
# num_update_steps_per_epoch = math.ceil(
#     len(train_dataloader) / cfg.training.gradient_accumulation_steps
# )
# if cfg.training.max_train_steps is None:
#     cfg.training.max_train_steps = (
#         cfg.training.num_epochs * num_update_steps_per_epoch
#     )
#     overrode_max_train_steps = True

# # We need to recalculate our total training steps as the size of the training dataloader
# # may have changed.
# # num_update_steps_per_epoch = math.ceil(
# #     len(train_dataloader) / cfg.training.gradient_accumulation_steps
# # )
# # if overrode_max_train_steps:
# #     cfg.training.max_train_steps = (
# #         cfg.training.num_epochs * num_update_steps_per_epoch
# #     )
# # Afterwards we recalculate our number of training epochs
# cfg.training.num_epochs = math.ceil(
#     cfg.training.max_train_steps / num_update_steps_per_epoch
# )

In [18]:
from tqdm.auto import tqdm
import logging
from accelerate.logging import get_logger

In [19]:
# from tqdm.auto import tqdm
# import logging
# from accelerate.logging import get_logger

# if cfg.tracking.enabled is True and accelerator.is_main_process:
#     experiment_config = vars(cfg)
#     # TensorBoard cannot log Enums, need the raw value
#     experiment_config["lr_scheduler_type"] = cfg.training.lr_scheduler
#     accelerator.init_trackers("finetune_using_clm", experiment_config)

# logger = get_logger(__name__)
# logging.basicConfig(
#     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
#     datefmt="%m/%d/%Y %H:%M:%S",
#     level=logging.INFO,
# )
    
# logger.info("***** Running training *****")
# logger.info(f"  Num examples = {len(train_dataset)}")
# logger.info(f"  Num Epochs = {cfg.training.num_epochs}")
# logger.info(
#     f"  Gradient Accumulation steps = {cfg.training.gradient_accumulation_steps}"
# )
# logger.info(f"  Total optimization steps = {cfg.training.max_train_steps}")

# # Only show the progress bar once on each machine.
# progress_bar = tqdm(
#     range(cfg.training.max_train_steps),
#     disable=not accelerator.is_local_main_process,
# )

In [20]:
os.environ["CUDA_VISIBLE_DEVICES"]= "0, 1, 2, 3"

In [28]:
def training_loop():
    
    cfg = OmegaConf.load('conf/config.yaml')
    cfg = check_cfg_and_load_defaults(cfg)

    tokenized_datasets = load_dataset(cfg)
    tokenizer = load_tokenizer(cfg)

    print('a')
    accelerator = create_accelerator(cfg)
    print('b')
    accelerator.wait_for_everyone()
    print('c')

    model = load_model(cfg, tokenizer)
    optimizer = create_optimizer(cfg, model)
    lr_scheduler = get_scheduler(
        name=cfg.training.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.training.lr_warmup_steps,
        num_training_steps=cfg.training.max_train_steps,
    )


    if "train" not in tokenized_datasets.column_names:
        tokenized_datasets = tokenized_datasets.train_test_split(
            test_size=cfg.training.val_split_percent / 100
        )
        tokenized_datasets_test_valid = tokenized_datasets["test"].train_test_split(
            test_size=0.5
        )
        tokenized_datasets["test"] = tokenized_datasets_test_valid["train"]
        tokenized_datasets["validation"] = tokenized_datasets_test_valid["test"]
        
    print(f"train shape {tokenized_datasets['train'].shape}")
    print(f"eval shape {tokenized_datasets['validation'].shape}")
    print(f"test shape {tokenized_datasets['test'].shape}")
    

    train_dataset = tokenized_datasets["train"]
    eval_dataset = tokenized_datasets["validation"]
    test_dataset = tokenized_datasets["test"]


    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=default_data_collator,
        batch_size=cfg.training.train_batch_size,
    )

    eval_dataloader = DataLoader(
        eval_dataset,
        collate_fn=default_data_collator,
        batch_size=cfg.training.eval_batch_size,
    )

    test_dataloader = DataLoader(
        test_dataset,
        collate_fn=default_data_collator,
        batch_size=cfg.training.eval_batch_size,
    )



    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / cfg.training.gradient_accumulation_steps
    )
    if cfg.training.max_train_steps is None:
        cfg.training.max_train_steps = (
            cfg.training.num_epochs * num_update_steps_per_epoch
        )
        overrode_max_train_steps = True

    # We need to recalculate our total training steps as the size of the training dataloader
    # may have changed.
    # num_update_steps_per_epoch = math.ceil(
    #     len(train_dataloader) / cfg.training.gradient_accumulation_steps
    # )
    # if overrode_max_train_steps:
    #     cfg.training.max_train_steps = (
    #         cfg.training.num_epochs * num_update_steps_per_epoch
    #     )
    # Afterwards we recalculate our number of training epochs
    cfg.training.num_epochs = math.ceil(
        cfg.training.max_train_steps / num_update_steps_per_epoch
    )


    if cfg.tracking.enabled is True and accelerator.is_main_process:
        experiment_config = vars(cfg)
        # TensorBoard cannot log Enums, need the raw value
        experiment_config["lr_scheduler_type"] = cfg.training.lr_scheduler
        accelerator.init_trackers("finetune_using_clm", experiment_config)

    logger = get_logger(__name__)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
        
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {cfg.training.num_epochs}")
    logger.info(
        f"  Gradient Accumulation steps = {cfg.training.gradient_accumulation_steps}"
    )
    logger.info(f"  Total optimization steps = {cfg.training.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(cfg.training.max_train_steps),
        disable=not accelerator.is_local_main_process,
    )



    completed_steps = 0
    starting_epoch = 0

    # Add resume training functionality

    epoch_durations = []
    train_checkpoint_durations = []

    with wandb.init(project=cfg.project_name, config=cfg, name=cfg.model.name):
        for epoch in range(starting_epoch, cfg.training.num_epochs):
            epoch_start_time = time.time()
            model.train()
            # if cfg.tracking.enabled is True:
            total_loss = 0
            train_losses = []
            train_checkpoint_start_time = time.time()

            for step, batch in enumerate(train_dataloader):
                # We need to skip steps until we reach the resumed step
                # if (
                #     cfg.training.checkpoint.resume_from_checkpoint
                #     and epoch == starting_epoch
                # ):
                #     if resume_step is not None and step < resume_step:
                #         completed_steps += 1
                #         continue
                
                # Set random token to 0 attention
                tmp_attn = batch['attention_mask']
                batch['attention_mask'] = attention_zeros(tmp_attn)

                outputs = model(**batch)
                loss = outputs.loss
                train_losses.append(
                    accelerator.gather(loss.repeat(cfg.training.train_batch_size))
                )
            # We keep track of the loss at each epoch
                if cfg.tracking.enabled is True:
                    total_loss += loss.detach().float()
                loss = loss / cfg.training.gradient_accumulation_steps
                accelerator.backward(loss)

                if (
                        step % cfg.training.gradient_accumulation_steps == 0
                        or step == len(train_dataloader) - 1
                    ):
                        optimizer.step()
                        lr_scheduler.step()
                        optimizer.zero_grad()
                        progress_bar.update(1)
                        completed_steps += 1


                if step % cfg.training.eval_every == 0:
                    

                    train_losses_tensor = torch.cat(train_losses)
                    train_loss = torch.mean(train_losses_tensor)
                    model.eval()
                    eval_losses = []
                    for _eval_step, eval_batch in enumerate(eval_dataloader):
                        with torch.no_grad():
                            outputs = model(**eval_batch)

                        loss = outputs.loss
                        eval_losses.append(
                            accelerator.gather(loss.repeat(cfg.training.eval_batch_size))
                        )

                    losses = torch.cat(eval_losses)
                    losses = losses[: len(eval_dataset)]
                    try:
                        eval_loss = torch.mean(losses)
                        perplexity = math.exp(eval_loss)
                    except OverflowError:
                        perplexity = float("inf")

                    train_checkpoint_duration = time.time() - train_checkpoint_start_time
                    train_checkpoint_durations.append(train_checkpoint_duration)
                    
                    wandb.log({"train_loss": train_loss, "epoch": epoch, 'eval_loss': eval_loss, 'eval_perplexity': perplexity, 'train_checkpoint_duration': train_checkpoint_duration}, step=step)
                    logger.info(
                        f"epoch {epoch}: eval_perplexity: {perplexity} train_loss: {train_loss} eval_loss: {eval_loss} 'train_checkpoint_duration': {train_checkpoint_duration} step: {step}"
                    )

                    train_checkpoint_start_time = time.time()

                    # epoch_dir = f"epoch_{epoch}_most_recent"
                    # if cfg.output_dir is not None:
                    #     output_dir = os.path.join(cfg.output_dir, epoch_dir)
                    # unwrapped_model = accelerator.unwrap_model(model)
                    # unwrapped_model.save_pretrained(
                    #     output_dir,
                    #     is_main_process=accelerator.is_main_process,
                    #     save_function=accelerator.save,
                    # )
                    # if accelerator.is_main_process:
                    #     tokenizer.save_pretrained(output_dir)

            train_loss = total_loss.item() / len(train_dataloader)

            # Below was causing WandB communication errors
            # if cfg.tracking.enabled is True:
            #     accelerator.log(
            #         {
            #             "perplexity": perplexity,
            #             "eval_loss": eval_loss,
            #             "train_loss": train_loss,
            #             "epoch": epoch,
            #             "step": completed_steps,
            #         },
            #         step=completed_steps,
            #     )


            epoch_duration = time.time() - epoch_start_time
            epoch_durations.append(epoch_duration)

            wandb.log({"train_loss": train_loss, "epoch": epoch, 'eval_loss': eval_loss, 'eval_perplexity': perplexity, 'epoch_duration': epoch_duration}, step=step)
            logger.info(f"done epoch {epoch}")


        avg_epoch_runtime = sum(epoch_durations) / len(epoch_durations)
        avg_train_checkpoint_runtime = sum(train_checkpoint_durations) / len(train_checkpoint_durations)

        wandb.log({"avg epoch runtime (seconds)": avg_epoch_runtime})
        wandb.log({"avg train checkpoint runtime (seconds)": avg_train_checkpoint_runtime})
        wandb.finish()

    if cfg.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            cfg.output_dir,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
        )
        if accelerator.is_main_process:
            tokenizer.save_pretrained(cfg.output_dir)

In [29]:
from accelerate import notebook_launcher

notebook_launcher(training_loop, num_processes=4, use_port=29504)

Launching training on 4 GPUs.
loaded data from /mnt/share/ipfs-data/data/tokenized_data_0.
/mnt/share/ipfs-data/data/tokenized_data_1 doesn't exist.
loaded data from /mnt/share/ipfs-data/data/tokenized_data_0.
/mnt/share/ipfs-data/data/tokenized_data_1 doesn't exist.
loaded data from /mnt/share/ipfs-data/data/tokenized_data_0.
/mnt/share/ipfs-data/data/tokenized_data_1 doesn't exist.
loaded data from /mnt/share/ipfs-data/data/tokenized_data_0.
/mnt/share/ipfs-data/data/tokenized_data_1 doesn't exist.


Using pad_token, but it is not set yet.


a


Using pad_token, but it is not set yet.


a


Using pad_token, but it is not set yet.


a


Using pad_token, but it is not set yet.


a


In [27]:
from accelerate import notebook_launcher

notebook_launcher(training_loop, num_processes=4)

Launching training on 4 GPUs.
loaded data from /mnt/share/ipfs-data/data/tokenized_data_0.loaded data from /mnt/share/ipfs-data/data/tokenized_data_0.loaded data from /mnt/share/ipfs-data/data/tokenized_data_0.loaded data from /mnt/share/ipfs-data/data/tokenized_data_0.



/mnt/share/ipfs-data/data/tokenized_data_1 doesn't exist./mnt/share/ipfs-data/data/tokenized_data_1 doesn't exist./mnt/share/ipfs-data/data/tokenized_data_1 doesn't exist./mnt/share/ipfs-data/data/tokenized_data_1 doesn't exist.





Using pad_token, but it is not set yet.
Using pad_token, but it is not set yet.
Using pad_token, but it is not set yet.
Using pad_token, but it is not set yet.


In [None]:
for step, batch in enumerate(train_dataloader):
    # We need to skip steps until we reach the resumed step
    # if (
    #     cfg.training.checkpoint.resume_from_checkpoint
    #     and epoch == starting_epoch
    # ):
    #     if resume_step is not None and step < resume_step:
    #         completed_steps += 1
    #         continue

    print(batch)
    print(f"length batch {len(batch['attention_mask'][0])}")
    attn = batch['input_ids']
# print(**batch)

    break
    # outputs = model(**batch)

{'input_ids': tensor([[ 5562,   717,  2239,  ..., 26350,   329,  1811],
        [ 2435,   290,   345,  ...,   326,   339,   550],
        [ 1662,  3051,   287,  ...,    11,   810,   262],
        ...,
        [ 1462,   423,   640,  ..., 10846,  3332,   287],
        [13893,   356,   550,  ...,   503,   612,   379],
        [  505,  1517,   878,  ...,   503,   465,  4324]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]]), 'labels': tensor([[ 5562,   717,  2239,  ..., 26350,   329,  1811],
        [ 2435,   290,   345,  ...,   326,   339,   550],
        [ 1662,  3051,   287,  ...,    11,   810,   262],
        ...,
        [ 1462,   423,   640,  ..., 10846,  3332,   287],
        [13893,   356,   550,  ...,   503,   612,   379],
        [  505,  1517,   878,  ...,   503,   465,  4324]])}
length batch 

In [None]:
attn.size()

torch.Size([32, 256])

In [None]:
len(train_dataloader)*3

22971

In [None]:
import random
import numpy as np
hi = np.random.randint(0, 255, 32)

In [None]:
len(hi)

32

In [None]:
attn_z = attention_zeros(attn)

In [None]:
print(attn[0])

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


In [None]:
print(attn_z[0])

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


In [None]:
import torch

In [None]:
len(attn)

32

In [None]:
attn_zeros = attn

for i, k in enumerate(zip(hi, attn)):
    # 
    # print(k)
    attn_zeros[i][k[1]][k[0]] = 0
    # print(k)
    # print(k[0])
    # print(k[1])
    # print(k[hi])
    # break
    # print(k)

# new_attn = atten

In [None]:
attn

tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])

In [None]:
attn_zeros[1]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])