In [3]:
from transformers import set_seed, TrainingArguments, AutoTokenizer
import pandas as pd
set_seed(42)

In [74]:
import torch
from torch import nn
from transformers import BertPreTrainedModel

class ColbertModel(BertPreTrainedModel):

    def __init__(self, config):
        super(ColbertModel, self).__init__(config)

        #모델 수정 가능-현재는 기존 BertModel 사용중
        self.similarity_metric = 'cosine'
        self.dim = 128
        self.batch = 4
        self.bert = BertModel(config)
        self.init_weights()
        self.linear = nn.Linear(config.hidden_size, self.dim, bias=False)  


    def forward(self, q_inputs,c_inputs):
        Q = self.query(**q_inputs)
        D = self.doc(**c_inputs)
        return self.get_score(Q,D)


    def query(self, input_ids, attention_mask, token_type_ids):
        Q = self.bert(input_ids, attention_mask=attention_mask,token_type_ids=token_type_ids)[0]
        Q = self.linear(Q)
        return torch.nn.functional.normalize(Q, p=2, dim=2)


    def doc(self, input_ids, attention_mask, token_type_ids):
        D = self.bert(input_ids, attention_mask=attention_mask,token_type_ids=token_type_ids)[0]
        D = self.linear(D)
        return torch.nn.functional.normalize(D, p=2, dim=2)

    def get_score(self,Q,D,eval=False):
        if eval:
            if self.similarity_metric == 'cosine':
                print('1', end='')
                final_score=torch.tensor([])
                for D_batch in tqdm(D):
                    D_batch = torch.Tensor(D_batch).squeeze()
                    p_seqeunce_output=D_batch.transpose(1,2) #(batch_size,hidden_size,p_sequence_length)
                    q_sequence_output=Q.view(600 ,1,-1,self.dim) #(batch_size, 1, q_sequence_length, hidden_size)
                    dot_prod = torch.matmul(q_sequence_output,p_seqeunce_output) #(batch_size,batch_size, q_sequence_length, p_seqence_length)
                    max_dot_prod_score =torch.max(dot_prod, dim=3)[0] #(batch_size,batch_size,q_sequnce_length)
                    score = torch.sum(max_dot_prod_score,dim=2)#(batch_size,batch_size)
                    final_score = torch.cat([final_score,score],dim=1)
                print(final_score.size())
                return final_score

        else:
            if self.similarity_metric == 'cosine':
                print('2', end='')
                # print(D.shape)
                # print(Q.shape)
                p_seqeunce_output=D.transpose(1,2) #(batch_size,hidden_size,p_sequence_length)
                q_sequence_output=Q.view(self.batch,1,-1,self.dim) #(batch_size, 1, q_sequence_length, hidden_size)

                dot_prod = torch.matmul(q_sequence_output,p_seqeunce_output) #(batch_size,batch_size, q_sequence_length, p_seqence_length)
                max_dot_prod_score =torch.max(dot_prod, dim=3)[0] #(batch_size,batch_size,q_sequnce_length)
                final_score = torch.sum(max_dot_prod_score,dim=2)#(batch_size,batch_size)
            
                return final_score

    def score(self, Q, D):
        if self.similarity_metric == 'cosine':
            return (Q @ D.permute(0, 2, 1)).max(2).values.sum(1)
        assert self.similarity_metric == 'l2'
        return (-1.0 * ((Q.unsqueeze(2) - D.unsqueeze(1))**2).sum(-1)).max(-1).values.sum(-1)


In [75]:
from transformers import AutoConfig, BertModel

model_name = 'klue/bert-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)
special_tokens={'additional_special_tokens' :['[Q]','[D]']}
tokenizer.add_special_tokens(special_tokens)

args = TrainingArguments(
    num_train_epochs=5,
    weight_decay=0.1,
    output_dir = '../data/colbert',
    per_device_train_batch_size=4)
model_config = AutoConfig.from_pretrained(model_name)
model = ColbertModel.from_pretrained(model_name)
model.resize_token_embeddings(tokenizer.vocab_size + 2)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

Some weights of the model checkpoint at klue/bert-base were not used when initializing ColbertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing ColbertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ColbertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ColbertModel were not initialized from the model checkpoint at klue/bert-base and are newly initializ

ColbertModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(32002, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              

In [56]:
data = pd.read_csv('/opt/ml/final-project-level3-nlp-09/data/train.csv')[['question', 'context']]

In [57]:
def tokenize_for_train(data, tokenizer):
    
    questions = ['[Q] ' + question for question in list(data['question'])]
    contexts = ['[D] ' + context for context in list(data['context'])]

    tokenized_question = tokenizer( questions,
                                    return_tensors='pt',
                                    padding=True,
                                    truncation=True,
                                    max_length = 128 )
    tokenized_context = tokenizer(contexts,
                                  return_tensors='pt',
                                  padding='max_length',
                                  truncation=True)
    return tokenized_question, tokenized_context
    
tokenized_question, tokenized_context= tokenize_for_train(data, tokenizer)


In [58]:
from torch.utils.data import TensorDataset
train_data = TensorDataset(tokenized_question['input_ids'], tokenized_question['attention_mask'], tokenized_question['token_type_ids'],
tokenized_context['input_ids'], tokenized_context['attention_mask'], tokenized_context['token_type_ids'])

In [59]:
from torch.utils.data import (RandomSampler, DataLoader)
from torch.optim import AdamW
from tqdm import tqdm
import trange

In [60]:
#train 함수 분리 
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.per_device_train_batch_size)

In [77]:
import torch.nn.functional as F
from tqdm import tqdm
global_step = 0

model.zero_grad()
torch.cuda.empty_cache()

# train_iterator = trange( int(args.num_train_epochs))#, desc='Epoch')
for epoch in range(args.num_train_epochs) :
    epoch_iterator = tqdm(train_dataloader, desc="Iteration")
    total_loss = 0
    steps = 0

    for step, data_set in enumerate(epoch_iterator):
        steps += 1
        model.train()

        if torch.cuda.is_available():
                data_set = tuple(t.cuda() for t in data_set)


        q_inputs = {'input_ids': data_set[0],
                    'attention_mask': data_set[1],
                    'token_type_ids': data_set[2]}

        c_inputs = { 'input_ids': data_set[3],
                      'attention_mask': data_set[4],
                      'token_type_ids': data_set[5] }
        
        outputs = model(q_inputs, c_inputs)
        
        targets = torch.arange(0, args.per_device_train_batch_size).long()
        if torch.cuda.is_available():
            targets = targets.to('cuda')
        
        sim_scores = F.log_softmax(outputs, dim=1)
        loss = F.nll_loss(sim_scores, targets)

        total_loss += loss
        loss.backward()
        # optimizer.step()
        # schedular.step()
        model.zero_grad()
        global_step += 1
        torch.cuda.empty_cache()
        
    torch.save(model.state_dict(), f'./colbert_models/{epoch+1}_model.pth')
    print(total_loss/step)


Iteration:   0%|          | 1/315 [00:00<02:12,  2.37it/s]

2

Iteration:   1%|          | 2/315 [00:00<01:27,  3.57it/s]

22

Iteration:   1%|▏         | 4/315 [00:00<01:06,  4.65it/s]

22

Iteration:   2%|▏         | 6/315 [00:01<01:00,  5.12it/s]

22

Iteration:   3%|▎         | 8/315 [00:01<00:57,  5.35it/s]

22

Iteration:   3%|▎         | 10/315 [00:02<00:56,  5.45it/s]

22

Iteration:   4%|▍         | 12/315 [00:02<00:55,  5.50it/s]

22

Iteration:   4%|▍         | 14/315 [00:02<00:54,  5.54it/s]

22

Iteration:   5%|▌         | 16/315 [00:03<00:53,  5.55it/s]

22

Iteration:   6%|▌         | 18/315 [00:03<00:53,  5.55it/s]

22

Iteration:   6%|▋         | 20/315 [00:03<00:53,  5.56it/s]

22

Iteration:   7%|▋         | 22/315 [00:04<00:52,  5.58it/s]

22

Iteration:   8%|▊         | 24/315 [00:04<00:52,  5.58it/s]

22

Iteration:   8%|▊         | 26/315 [00:04<00:51,  5.58it/s]

22

Iteration:   9%|▉         | 28/315 [00:05<00:51,  5.55it/s]

22

Iteration:  10%|▉         | 30/315 [00:05<00:51,  5.54it/s]

22

Iteration:  10%|█         | 32/315 [00:06<00:51,  5.54it/s]

22

Iteration:  11%|█         | 34/315 [00:06<00:50,  5.55it/s]

22

Iteration:  11%|█▏        | 36/315 [00:06<00:50,  5.57it/s]

22

Iteration:  12%|█▏        | 38/315 [00:07<00:49,  5.59it/s]

22

Iteration:  13%|█▎        | 40/315 [00:07<00:49,  5.57it/s]

22

Iteration:  13%|█▎        | 42/315 [00:07<00:49,  5.56it/s]

22

Iteration:  14%|█▍        | 44/315 [00:08<00:48,  5.56it/s]

22

Iteration:  15%|█▍        | 46/315 [00:08<00:48,  5.57it/s]

22

Iteration:  15%|█▌        | 48/315 [00:08<00:47,  5.58it/s]

22

Iteration:  16%|█▌        | 50/315 [00:09<00:47,  5.56it/s]

22

Iteration:  17%|█▋        | 52/315 [00:09<00:47,  5.54it/s]

22

Iteration:  17%|█▋        | 54/315 [00:09<00:46,  5.55it/s]

22

Iteration:  18%|█▊        | 56/315 [00:10<00:46,  5.56it/s]

22

Iteration:  18%|█▊        | 58/315 [00:10<00:46,  5.56it/s]

22

Iteration:  19%|█▉        | 60/315 [00:11<00:45,  5.57it/s]

22

Iteration:  20%|█▉        | 62/315 [00:11<00:45,  5.57it/s]

22

Iteration:  20%|██        | 64/315 [00:11<00:44,  5.58it/s]

22

Iteration:  21%|██        | 66/315 [00:12<00:44,  5.58it/s]

22

Iteration:  22%|██▏       | 68/315 [00:12<00:44,  5.57it/s]

22

Iteration:  22%|██▏       | 70/315 [00:12<00:43,  5.59it/s]

22

Iteration:  23%|██▎       | 72/315 [00:13<00:43,  5.56it/s]

22

Iteration:  23%|██▎       | 74/315 [00:13<00:43,  5.57it/s]

22

Iteration:  24%|██▍       | 76/315 [00:13<00:42,  5.57it/s]

22

Iteration:  25%|██▍       | 78/315 [00:14<00:42,  5.56it/s]

22

Iteration:  25%|██▌       | 80/315 [00:14<00:42,  5.58it/s]

22

Iteration:  26%|██▌       | 82/315 [00:14<00:41,  5.56it/s]

22

Iteration:  27%|██▋       | 84/315 [00:15<00:41,  5.56it/s]

22

Iteration:  27%|██▋       | 86/315 [00:15<00:40,  5.59it/s]

22

Iteration:  28%|██▊       | 88/315 [00:16<00:40,  5.57it/s]

22

Iteration:  29%|██▊       | 90/315 [00:16<00:40,  5.57it/s]

22

Iteration:  29%|██▉       | 92/315 [00:16<00:39,  5.58it/s]

22

Iteration:  30%|██▉       | 94/315 [00:17<00:39,  5.58it/s]

22

Iteration:  30%|███       | 96/315 [00:17<00:39,  5.58it/s]

22

Iteration:  31%|███       | 98/315 [00:17<00:38,  5.57it/s]

22

Iteration:  32%|███▏      | 100/315 [00:18<00:38,  5.57it/s]

22

Iteration:  32%|███▏      | 102/315 [00:18<00:38,  5.57it/s]

22

Iteration:  33%|███▎      | 104/315 [00:18<00:37,  5.56it/s]

22

Iteration:  34%|███▎      | 106/315 [00:19<00:37,  5.58it/s]

22

Iteration:  34%|███▍      | 108/315 [00:19<00:37,  5.59it/s]

22

Iteration:  35%|███▍      | 110/315 [00:20<00:36,  5.58it/s]

22

Iteration:  36%|███▌      | 112/315 [00:20<00:36,  5.60it/s]

22

Iteration:  36%|███▌      | 114/315 [00:20<00:36,  5.56it/s]

22

Iteration:  37%|███▋      | 116/315 [00:21<00:35,  5.57it/s]

22

Iteration:  37%|███▋      | 118/315 [00:21<00:35,  5.58it/s]

22

Iteration:  38%|███▊      | 120/315 [00:21<00:35,  5.57it/s]

22

Iteration:  39%|███▊      | 122/315 [00:22<00:34,  5.57it/s]

22

Iteration:  39%|███▉      | 124/315 [00:22<00:34,  5.56it/s]

22

Iteration:  40%|████      | 126/315 [00:22<00:33,  5.57it/s]

22

Iteration:  41%|████      | 128/315 [00:23<00:33,  5.58it/s]

22

Iteration:  41%|████▏     | 130/315 [00:23<00:33,  5.55it/s]

22

Iteration:  42%|████▏     | 132/315 [00:23<00:32,  5.55it/s]

22

Iteration:  43%|████▎     | 134/315 [00:24<00:32,  5.57it/s]

22

Iteration:  43%|████▎     | 136/315 [00:24<00:32,  5.56it/s]

22

Iteration:  44%|████▍     | 138/315 [00:25<00:32,  5.51it/s]

22

Iteration:  44%|████▍     | 140/315 [00:25<00:31,  5.53it/s]

22

Iteration:  45%|████▌     | 142/315 [00:25<00:31,  5.53it/s]

22

Iteration:  46%|████▌     | 144/315 [00:26<00:30,  5.53it/s]

22

Iteration:  46%|████▋     | 146/315 [00:26<00:30,  5.54it/s]

22

Iteration:  47%|████▋     | 148/315 [00:26<00:30,  5.48it/s]

22

Iteration:  48%|████▊     | 150/315 [00:27<00:29,  5.50it/s]

22

Iteration:  48%|████▊     | 152/315 [00:27<00:29,  5.54it/s]

22

Iteration:  49%|████▉     | 154/315 [00:27<00:29,  5.55it/s]

22

Iteration:  50%|████▉     | 156/315 [00:28<00:28,  5.53it/s]

22

Iteration:  50%|█████     | 158/315 [00:28<00:28,  5.55it/s]

22

Iteration:  51%|█████     | 160/315 [00:29<00:27,  5.57it/s]

22

Iteration:  51%|█████▏    | 162/315 [00:29<00:27,  5.58it/s]

22

Iteration:  52%|█████▏    | 164/315 [00:29<00:26,  5.59it/s]

22

Iteration:  53%|█████▎    | 166/315 [00:30<00:26,  5.57it/s]

22

Iteration:  53%|█████▎    | 168/315 [00:30<00:26,  5.58it/s]

22

Iteration:  54%|█████▍    | 170/315 [00:30<00:25,  5.60it/s]

22

Iteration:  55%|█████▍    | 172/315 [00:31<00:25,  5.61it/s]

22

Iteration:  55%|█████▌    | 174/315 [00:31<00:25,  5.60it/s]

22

Iteration:  56%|█████▌    | 176/315 [00:31<00:24,  5.60it/s]

22

Iteration:  57%|█████▋    | 178/315 [00:32<00:24,  5.58it/s]

22

Iteration:  57%|█████▋    | 180/315 [00:32<00:24,  5.56it/s]

22

Iteration:  58%|█████▊    | 182/315 [00:32<00:23,  5.55it/s]

22

Iteration:  58%|█████▊    | 184/315 [00:33<00:23,  5.54it/s]

22

Iteration:  59%|█████▉    | 186/315 [00:33<00:23,  5.52it/s]

22

Iteration:  60%|█████▉    | 188/315 [00:34<00:23,  5.51it/s]

22

Iteration:  60%|██████    | 190/315 [00:34<00:22,  5.52it/s]

22

Iteration:  61%|██████    | 192/315 [00:34<00:22,  5.53it/s]

22

Iteration:  62%|██████▏   | 194/315 [00:35<00:21,  5.53it/s]

22

Iteration:  62%|██████▏   | 196/315 [00:35<00:21,  5.53it/s]

22

Iteration:  63%|██████▎   | 198/315 [00:35<00:21,  5.52it/s]

22

Iteration:  63%|██████▎   | 200/315 [00:36<00:20,  5.52it/s]

22

Iteration:  64%|██████▍   | 202/315 [00:36<00:20,  5.52it/s]

22

Iteration:  65%|██████▍   | 204/315 [00:36<00:20,  5.53it/s]

22

Iteration:  65%|██████▌   | 206/315 [00:37<00:19,  5.53it/s]

22

Iteration:  66%|██████▌   | 208/315 [00:37<00:19,  5.53it/s]

22

Iteration:  67%|██████▋   | 210/315 [00:38<00:18,  5.55it/s]

22

Iteration:  67%|██████▋   | 212/315 [00:38<00:18,  5.56it/s]

22

Iteration:  68%|██████▊   | 214/315 [00:38<00:18,  5.58it/s]

22

Iteration:  69%|██████▊   | 216/315 [00:39<00:17,  5.58it/s]

22

Iteration:  69%|██████▉   | 218/315 [00:39<00:17,  5.57it/s]

22

Iteration:  70%|██████▉   | 220/315 [00:39<00:17,  5.50it/s]

22

Iteration:  70%|███████   | 222/315 [00:40<00:16,  5.54it/s]

22

Iteration:  71%|███████   | 224/315 [00:40<00:16,  5.55it/s]

22

Iteration:  72%|███████▏  | 226/315 [00:40<00:15,  5.56it/s]

22

Iteration:  72%|███████▏  | 228/315 [00:41<00:15,  5.54it/s]

22

Iteration:  73%|███████▎  | 230/315 [00:41<00:15,  5.53it/s]

22

Iteration:  74%|███████▎  | 232/315 [00:41<00:15,  5.53it/s]

22

Iteration:  74%|███████▍  | 234/315 [00:42<00:14,  5.52it/s]

22

Iteration:  75%|███████▍  | 236/315 [00:42<00:14,  5.52it/s]

22

Iteration:  76%|███████▌  | 238/315 [00:43<00:13,  5.52it/s]

22

Iteration:  76%|███████▌  | 240/315 [00:43<00:13,  5.52it/s]

22

Iteration:  77%|███████▋  | 242/315 [00:43<00:13,  5.51it/s]

22

Iteration:  77%|███████▋  | 244/315 [00:44<00:12,  5.51it/s]

22

Iteration:  78%|███████▊  | 246/315 [00:44<00:12,  5.51it/s]

22

Iteration:  79%|███████▊  | 248/315 [00:44<00:12,  5.51it/s]

22

Iteration:  79%|███████▉  | 250/315 [00:45<00:11,  5.49it/s]

22

Iteration:  80%|████████  | 252/315 [00:45<00:11,  5.49it/s]

22

Iteration:  81%|████████  | 254/315 [00:45<00:11,  5.50it/s]

22

Iteration:  81%|████████▏ | 256/315 [00:46<00:10,  5.50it/s]

22

Iteration:  82%|████████▏ | 258/315 [00:46<00:10,  5.50it/s]

22

Iteration:  83%|████████▎ | 260/315 [00:47<00:10,  5.50it/s]

22

Iteration:  83%|████████▎ | 262/315 [00:47<00:09,  5.50it/s]

22

Iteration:  84%|████████▍ | 264/315 [00:47<00:09,  5.51it/s]

22

Iteration:  84%|████████▍ | 266/315 [00:48<00:08,  5.50it/s]

22

Iteration:  85%|████████▌ | 268/315 [00:48<00:08,  5.51it/s]

22

Iteration:  86%|████████▌ | 270/315 [00:48<00:08,  5.50it/s]

22

Iteration:  86%|████████▋ | 272/315 [00:49<00:07,  5.49it/s]

22

Iteration:  87%|████████▋ | 274/315 [00:49<00:07,  5.50it/s]

22

Iteration:  88%|████████▊ | 276/315 [00:49<00:07,  5.50it/s]

22

Iteration:  88%|████████▊ | 278/315 [00:50<00:06,  5.48it/s]

22

Iteration:  89%|████████▉ | 280/315 [00:50<00:06,  5.49it/s]

22

Iteration:  90%|████████▉ | 282/315 [00:51<00:06,  5.49it/s]

22

Iteration:  90%|█████████ | 284/315 [00:51<00:05,  5.50it/s]

22

Iteration:  91%|█████████ | 286/315 [00:51<00:05,  5.50it/s]

22

Iteration:  91%|█████████▏| 288/315 [00:52<00:04,  5.49it/s]

22

Iteration:  92%|█████████▏| 290/315 [00:52<00:04,  5.50it/s]

22

Iteration:  93%|█████████▎| 292/315 [00:52<00:04,  5.49it/s]

22

Iteration:  93%|█████████▎| 294/315 [00:53<00:03,  5.43it/s]

22

Iteration:  94%|█████████▍| 296/315 [00:53<00:03,  5.46it/s]

22

Iteration:  95%|█████████▍| 298/315 [00:53<00:03,  5.48it/s]

22

Iteration:  95%|█████████▌| 300/315 [00:54<00:02,  5.47it/s]

22

Iteration:  96%|█████████▌| 302/315 [00:54<00:02,  5.48it/s]

22

Iteration:  97%|█████████▋| 304/315 [00:55<00:02,  5.49it/s]

22

Iteration:  97%|█████████▋| 306/315 [00:55<00:01,  5.49it/s]

22

Iteration:  98%|█████████▊| 308/315 [00:55<00:01,  5.50it/s]

22

Iteration:  98%|█████████▊| 310/315 [00:56<00:00,  5.48it/s]

22

Iteration:  99%|█████████▉| 312/315 [00:56<00:00,  5.49it/s]

22

Iteration: 100%|█████████▉| 314/315 [00:56<00:00,  5.51it/s]

22




RuntimeError: shape '[4, 1, -1, 128]' is invalid for input of size 7936

In [31]:
from transformers import BertPreTrainedModel

class Retrieval_Model(BertPreTrainedModel):
    def __init__(self, config):
        super(Retrieval_Model, self).__init__(config)

        #모델 수정 가능-현재는 기존 BertModel 사용중
        self.bert = BertModel(config)
        self.init_weights()
      
    def forward(self, input_ids, 
                attention_mask=None, token_type_ids=None): 
    
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        
        pooled_output = outputs[1]

        return pooled_output

In [None]:
from transformers import AutoConfig, BertModel
model_config = AutoConfig.from_pretrained(Model_Name)
q_encoder = Retrieval_Model.from_pretrained(Model_Name)
c_encoder = Retrieval_Model.from_pretrained(Model_Name)
q_encoder.to(device)
c_encoder.to(device)

In [36]:


train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.per_device_train_batch_size)

# nlp-11 Colbert 파라미터 참고
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
        {'params': [p for n, p in q_encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': Args.weight_decay},
        {'params': [p for n, p in q_encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
        {'params': [p for n, p in c_encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': Args.weight_decay},
        {'params': [p for n, p in c_encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]
optimizer = AdamW(optimizer_grouped_parameters, lr=Args.learning_rate, eps=arsgs.adam_epcilon)

q_encoder.zero_grad()
c_encoder.zero_grad()
torch.cuda.empty_cache()


for epoch in trange( int(Args.Num_train_epochs), desc='Epoch'):
    epoch_iterator = tqdm(train_dataloader, desc="Iteration")
    total_loss = 0
    steps = 0

    for step, data_set in enumerate(epoch_iterator):
        steps += 1
        q_encoder.train()
        c_encoder.train()

        q_inputs = {'input_ids': data_set[0],
                    'attention_mask': data_set[1],
                    'token_type_ids': data_set[2]}
        c_inputs = { 'input_ids': data_set[3],
                      'attention_maks': data_set[4],
                      'token_type_ids': data_set[5] }
        
        q_out = q_encoder(**p_inputs)
        c_out = c_encoder(**c_inputs)





#  q_encoder,c_encoder = train(args,train_dataset,p_encoder,q_encoder)

IndentationError: unexpected indent (1572693031.py, line 14)

In [46]:
from torch.utils.data import (RandomSampler, DataLoader)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=Args.per_device_train_batch_size)
train_sampler = RandomSampler(train_data)

for step, batch in enumerate(train_dataloader):
    # steps += 1
    # q_encoder.train()
    # c_encoder.train()
    print(step)
    print('-----------------------')
    s = batch
    break

    # p_inputs = {'input_ids': batch[0],
    #             'attention_mask': batch[1],
    #             'token_type_ids': batch[2]}
    # q_inputs = { 'input_ids': batch[3],
    #                 'attention_maks': batch[4],
    #                 'token_type_ids': batch[5] }

0
-----------------------


tensor([[    2,  1545, 26219,  ...,     0,     0,     0],
        [    2, 16039,  2242,  ...,     0,     0,     0],
        [    2,  1545,  2196,  ...,     0,     0,     0],
        ...,
        [    2,  5787,  2440,  ...,     0,     0,     0],
        [    2,  1545, 19621,  ...,     0,     0,     0],
        [    2,  6515,  2440,  ...,     0,     0,     0]])

In [42]:
\

<torch.utils.data.sampler.RandomSampler at 0x7fdad08518e0>