In [1]:
import json, typing, gc, os, sys,string
import torch
from torch import nn
import torch.nn.functional as F
import transformers
from transformers import AutoConfig, AutoModel, AutoTokenizer
import simpletransformers
import numpy as np
import pandas as pd
import time
from simpletransformers.question_answering import QuestionAnsweringModel, QuestionAnsweringArgs
import tokenizers
from tokenizers import BertWordPieceTokenizer
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from collections import Counter

In [2]:
def simpletransformers_qa():
    
    with open('MLQA_V1/test/test-context-hi-question-hi.json','r',encoding='utf-8') as f:
        dataset = json.load(f)
    data = dataset['data']

    mlqa_test = []
    for para in data:
        paras = para['paragraphs']
        for i in range(len(paras)):
            mlqa_test.append(paras[i])


    model_args = QuestionAnsweringArgs()
    model_args.fp16 = False
    model_args.train_custom_parameters_only = False
    model_args.warmup_steps = 3
    model_args.overwrite_output_dir = True
    model = QuestionAnsweringModel('distilbert', 'hi-lm-distilbert/',args=model_args)
    model.train_model(mlqa_test[:16])

In [3]:
def load_json(path):
    '''
    Loads the JSON file of the Squad dataset.
    Returns the json object of the dataset.
    '''
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
        
    print("Length of data: ", len(data['data']))
    print("Data Keys: ", data['data'][0].keys())
    print("Title: ", data['data'][0]['title'])
    
    return data


def parse_data(data:dict)->list:
    '''
    Parses the JSON file of Squad dataset by looping through the
    keys and values and returns a list of dictionaries with
    context, query and label triplets being the keys of each dict.
    '''
    data = data['data']
    qa_list = []

    for paragraphs in data:

        for para in paragraphs['paragraphs']:
            context = para['context']

            for qa in para['qas']:
                
                id = qa['id']
                question = qa['question']
                
                for ans in qa['answers']:
                    answer = ans['text']
                    ans_start = ans['answer_start']
                    ans_end = ans_start + len(answer)
                    
                    qa_dict = {}
                    qa_dict['id'] = id
                    qa_dict['context'] = context
                    qa_dict['question'] = question
                    qa_dict['label'] = [ans_start, ans_end]

                    qa_dict['answer'] = answer
                    qa_list.append(qa_dict)    

    return qa_list

In [4]:
mlqa_test_data = load_json('MLQA_V1/test/test-context-hi-question-hi.json')
mlqa_dev_data = load_json('MLQA_V1/dev/dev-context-hi-question-hi.json')
xsquad_data = load_json('xsquad-hi.json')

qa_list = parse_data(mlqa_test_data) + parse_data(mlqa_dev_data) + parse_data(xsquad_data)
print(len(qa_list))

Length of data:  2038
Data Keys:  dict_keys(['title', 'paragraphs'])
Title:  एरिया 51
Length of data:  217
Data Keys:  dict_keys(['title', 'paragraphs'])
Title:  संयंत्र सेल
Length of data:  48
Data Keys:  dict_keys(['paragraphs', 'title'])
Title:  Super_Bowl_50
6615


In [6]:
df = pd.DataFrame(qa_list)


In [7]:
df.head()

Unnamed: 0,id,context,question,label,answer
0,eeb8dbd25efe5221dc6723ddee95daa07d2c8478,"उसी ""एरिया XX "" नामकरण प्रणाली का प्रयोग नेवाद...",झील के सापेक्ष ग्रूम लेक रोड कहाँ जाती थी?,"[378, 389]",उत्तर पूर्व
1,ba7865d50777f2b90ba88fcb070a672d042b6b69,"में खानों की ओर जाती थीं, लेकिन उनके बंद होने ...",किस प्रकार की सड़कें बड़े खेतों और पशु-फार्मों त...,"[308, 316]",डर्ट-रोड
2,2079cf7ce47961738e4bd0d527d0b1058210f869,विश्व युद्ध II के दौरान ग्रूम झील का प्रयोग बम...,विमान के लिए परीक्षण पट्टी क्या बनी?,"[237, 247]",झील की सतह
3,d5377da63e6f64dae5e269290a6334c2a912cb3f,लॉकहीड ने इस स्थल पर एक अस्थायी अड्डे का निर्म...,प्रारंभिक u-2 वितरण के साथ कौन था?,"[330, 347]",लॉकहीड विशेषज्ञों
4,03df1f92420416844575cfa201ae840319c40650,"अधिकांश नेल्लिस सीमा के विपरीत, झील के आस-पास ...",प्रतिबंधित क्षेत्रों में भटकते पर सैन्य पायलटो...,"[366, 378]",अनुशासनात्मक


In [8]:
class SquadDataset:
    
    def __init__(self, tokenizer, context, answer, question, label, question_ids, max_len):
        
        self.context = context
        self.answer = answer
        self.question = question
        self.tokenizer = tokenizer
        self.question_ids = question_ids
        self.max_len = max_len
        self.label = label
        
    def __len__(self):
        return len(self.question)
    
    def __getitem__(self, item):
        
        context = self.context[item]
        question = self.question[item]
        answer = self.answer[item]
        question_id = self.question_ids[item]
        label = self.label[item]
        long_example = False
        ans_len = len(answer)
        start_idx, end_idx = label[0], label[1]
        
#         for idx in (i for i, ch in enumerate(context) if ch == answer[0]):
            
#             if context[idx : idx+ans_len] == answer:
#                 start_idx = idx
#                 end_idx = idx + ans_len - 1
#                 break
        
        char_targets = [0] * len(context)
        if start_idx != -1 and end_idx != -1:
            for i in range(start_idx, end_idx):
                    char_targets[i] = 1
                    
        # [000000111111111000000]
        
        tokenized_ctx = self.tokenizer.encode(context)
        context_tokens = tokenized_ctx.tokens
        
        #omit CLS and SEP
        context_ids = tokenized_ctx.ids[1:-1]
        
        #omit CLS and SEP
        context_offsets = tokenized_ctx.offsets[1:-1]
        # [(0,0), (0,3), (4,6) ... ]
        label_idx = []
        for i, (offset1, offset2) in enumerate(context_offsets):
            if sum(char_targets[offset1: offset2]) > 0:
                label_idx.append(i)
                
        start_idx = label_idx[0]
        end_idx = label_idx[-1]
        
        tokenized_qtn = self.tokenizer.encode(question)
        question_tokens = tokenized_qtn.tokens
        question_ids = tokenized_qtn.ids[1:-1]
        question_offsets = tokenized_qtn.offsets[1:-1]
        
        CLS = [2]
        SEP = [3]
        
        input_ids = CLS + question_ids + SEP + context_ids + SEP 
        token_type_ids = [0] * (len(question_ids) + 2) + [1] * (len(context_ids) + 1)
        mask = [1] * len(token_type_ids)
        start_idx += len(question_ids) + 2
        end_idx += len(question_ids) + 2
        offsets = [(0,0)] * 2 + question_offsets + context_offsets + [(0,0)]
        
        if len(input_ids) > self.max_len:
            long_example = True
            
        
        padding_len = self.max_len - len(input_ids)
        if padding_len > 0:
            input_ids = input_ids + ([0] * padding_len)
            mask = mask + ([0] * padding_len)
            token_type_ids = token_type_ids + ([0] * padding_len)
            offsets = offsets + ([(0,0)] * padding_len)

        
        
        return {
            'input_ids':torch.tensor(input_ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
            'start_idx': torch.tensor(start_idx, dtype=torch.long),
            'end_idx': torch.tensor(end_idx, dtype=torch.long),
            'question_id':question_id,
            'context':context,
            'answer':answer,
            'offsets':torch.tensor(offsets, dtype=torch.long),
            'long_ex':long_example
        }        

In [9]:
tokenizer = BertWordPieceTokenizer('hi-lm-distilbert/vocab.txt')

In [10]:
dataset =  SquadDataset(tokenizer, 
                        df.context.values, 
                        df.answer.values, 
                        df.question.values,
                        df.label.values,
                        df.id.values, 
                        512)

In [11]:
def filter_long_examples(data):
    
    long_examples = []
    for i in range(len(data)):
        if dataset[i]['long_ex'] == True:
            long_examples.append(i)
    
    return long_examples

In [12]:
long_examples = filter_long_examples(dataset)

In [14]:
len(long_examples)

318

In [11]:
def test_dataset(data):
    error_indices = []
    for i in range(len(data)):
        try:
            x = data[i]
        except:
            error_indices.append(i)

    if len(error_indices) == 0:
        print("Test passed succesfully")
    else:
        print(f"Error indices: {error_indices}")
        return error_indices

In [12]:
idx = test_dataset(dataset)

Test passed succesfully


In [13]:
def test_answers(data, tokenizer):
    
    error_indices = []
    for i,example in enumerate(data):
        input_ids = example['input_ids'].tolist()
        start_idx = int(example['start_idx'])
        end_idx = int(example['end_idx'])
        ground_truth = tokenizer.decode(tokenizer.encode(example['answer']).ids)
        span = tokenizer.decode(input_ids[start_idx:end_idx+1])
        
        try:
            assert span == ground_truth
        except:
            error_indices.append(i)
    
    if len(error_indices) == 0:
        print(f"Test passed successfully")
    else:
        print(f"Error indices: {error_indices}")
        return error_indices
        

In [14]:
error_indices = test_answers(dataset, tokenizer)
remove_indices = set(error_indices + long_examples)

Error indices: [13, 118, 185, 448, 509, 685, 712, 855, 1130, 1280, 1496, 1584, 1679, 1753, 1879, 2176, 2223, 2542, 2676, 2686, 2697, 2724, 2746, 2877, 3172, 3188, 3247, 3400, 3447, 3707, 3739, 3801, 3933, 4081, 4147, 4174, 4490, 4557, 4640, 4782, 4846, 4892, 5062, 5185, 5282, 5410, 5411, 5515, 5862, 5929, 6398]


In [15]:
clean_dataset = []
for i in range(len(dataset)):
    if i not in remove_indices:
        clean_dataset.append(dataset[i])
    
assert len(clean_dataset) == len(dataset) - len(remove_indices)

In [16]:
train_dataset, valid_dataset = train_test_split(clean_dataset)

In [17]:
train_loader = DataLoader(train_dataset, 32)
valid_loader = DataLoader(valid_dataset, 32)

In [18]:
for batch in train_loader:
    pass

for batch in valid_loader:
    pass

In [19]:
len(train_dataset), len(valid_dataset)

(4687, 1563)

In [20]:
device = torch.device('cuda')
base_model = AutoModel.from_pretrained('hi-lm-distilbert/').to(device)
for param in base_model.parameters():
    param.requires_grad = False

In [21]:
class BertQA(nn.Module):
    
    def __init__(self, base_model):
        
        super().__init__()
        
        self.base_model = base_model
        self.dropout = nn.Dropout(0.1)
        self.linear = nn.Linear(768, 2)
    
    def forward(self, input_ids, mask, token_type_ids):
        
        with torch.no_grad():
            sequence_output = self.base_model(input_ids=input_ids, attention_mask=mask)[0]
        # sequence_output = [bs, num_tokens, 768]
        
        linear_out = self.linear(self.dropout(sequence_output))
        # [bs, num_tokens, 2]
        
        start_scores, end_scores = linear_out.split(1, dim=-1)
        # start_scores = [bs, num_tokens, 1]
        
        start_scores = start_scores.squeeze(-1)
        # [bs, num_tokens]
        
        end_scores = end_scores.squeeze(-1)
        # [bs, num_tokens]
        
        return start_scores, end_scores

In [40]:
model = BertQA(base_model).to(device)

In [41]:
optimizer = torch.optim.Adam(model.parameters(),lr=0.01)
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=1, verbose=True)

In [42]:
def train(model, loader):
    
    model.train()
    train_loss = 0.
    
    for bi, batch in enumerate(loader):
        
        if bi % 30 == 0:
            print(f"Starting batch: {bi}")
            
        
        input_ids = batch['input_ids'].to(device)
        mask = batch['mask'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        start_idx = batch['start_idx'].to(device)
        end_idx = batch['end_idx'].to(device)
        
        p1, p2 = model(input_ids, mask, token_type_ids)
        
        loss = F.cross_entropy(p1, start_idx) + F.cross_entropy(p2, end_idx)
        
        train_loss += loss.item()
        
        loss.backward()
        
        optimizer.step()
        
        optimizer.zero_grad()
    
    
    return train_loss/len(loader)
        
        
        

In [43]:
def validate(model, loader):
    
    model.eval()
    valid_loss = 0.
    em, f1 = 0., 0.
    predictions = {}
    
    for bi, batch in enumerate(loader):
        
        if bi % 30 == 0:
            print(f"Starting batch: {bi}")
            
        
        input_ids = batch['input_ids'].to(device)
        mask = batch['mask'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        start_idx = batch['start_idx'].to(device)
        end_idx = batch['end_idx'].to(device)
        context = batch['context']
        question_id = batch['question_id']
        offsets = batch['offsets']
        
        with torch.no_grad():
            
            p1, p2 = model(input_ids, mask, token_type_ids)

            loss = F.cross_entropy(p1, start_idx) + F.cross_entropy(p2, end_idx)

            valid_loss += loss.item()
            
            batch_size = p1.shape[0]
            starts = torch.argmax(torch.softmax(p1, dim=1), dim=1)
            ends = torch.argmax(torch.softmax(p2, dim=1), dim=1)

#             batch_size, c_len = p1.size()
#             ls = nn.LogSoftmax(dim=1)
#             mask = (torch.ones(c_len, c_len) * float('-inf')).to(device).tril(-1).unsqueeze(0).expand(batch_size, -1, -1)
#             score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + mask
#             score, s_idx = score.max(dim=1)
#             score, e_idx = score.max(dim=1)
#             s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze()
            
            for i in range(batch_size):
                q_id = question_id[i]
                start = starts[i]
                end = ends[i]
                pred = ""
                
                if end < start:
                    end = start
                    
                for ix in range(start, end+1):
                    pred += context[i][offsets[i][ix][0]:offsets[i][ix][1]]
                    if (ix+1) < len(offsets[i]) and offsets[i][ix][1] < offsets[i][ix+1][0]:
                        pred += " "
                
                predictions[q_id] = pred
                
            
            
            
            

            
#             for i in range(batch_size):
#                 q_id = question_id[i]
#                 start = starts[i]
#                 end = ends[i]
#                 pred = input_ids[i][start: end+1]
#                 pred = tokenizer.decode(pred.tolist())
#                 predictions[q_id] = pred

    em, f1 = get_metrics(predictions)
    return valid_loss/len(loader), em, f1
  
        

In [44]:
def get_metrics(predictions):
    f1 = exact_match = total = 0
    for i in range(len(valid_dataset)):
        question_id = valid_dataset[i]['question_id']
        ground_truth = valid_dataset[i]['answer']
        prediction = predictions[question_id]
        total += 1
        f1 += f1_score(prediction, ground_truth)
        exact_match += exact_match_score(prediction, ground_truth)

    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total

    return exact_match, f1

def normalize_answer(s):
    
    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def normalize(text):
        return tokenizer.decode(tokenizer.encode(text).ids)

    return white_space_fix(remove_punc(normalize(s)))

def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    return (normalize_answer(prediction) == normalize_answer(ground_truth))

import time
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:

train_losses = []
valid_losses = []
ems = []
f1s = []
epochs = 5
for epoch in range(epochs):
    print(f"Epoch {epoch+1}")
    start_time = time.time()
    
    train_loss = train(model, train_loader)
    valid_loss, em, f1 = validate(model, valid_loader)
    #scheduler.step(valid_loss)
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    ems.append(em)
    f1s.append(f1)
    
    print(f"Epoch train loss : {train_loss}| Time: {epoch_mins}m {epoch_secs}s")
    print(f"Epoch valid loss: {valid_loss}")
    print(f"Epoch EM: {em}")
    print(f"Epoch F1: {f1}")
    print("====================================================================================")
    

Epoch 1
Starting batch: 0
Starting batch: 30
Starting batch: 60
Starting batch: 90
Starting batch: 120
Starting batch: 0
Starting batch: 30
Epoch train loss : 8.39178753872307| Time: 6m 49s
Epoch valid loss: 7.829391440566705
Epoch EM: 3.710812539987204
Epoch F1: 10.26906456372915
Epoch 2
Starting batch: 0
Starting batch: 30
Starting batch: 60
Starting batch: 90
Starting batch: 120
Starting batch: 0


In [212]:
mask = torch.ones(c_len, c_len) * float('-inf')

In [213]:
mask = mask.tril(-1)

In [223]:
print(mask.shape)
mask

torch.Size([100, 100])


tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [-inf, 0., 0.,  ..., 0., 0., 0.],
        [-inf, -inf, 0.,  ..., 0., 0., 0.],
        ...,
        [-inf, -inf, -inf,  ..., 0., 0., 0.],
        [-inf, -inf, -inf,  ..., -inf, 0., 0.],
        [-inf, -inf, -inf,  ..., -inf, -inf, 0.]])

In [218]:
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])

In [221]:
a = a.tril(-1)

In [222]:
a

tensor([[0, 0, 0],
        [4, 0, 0],
        [7, 8, 0]])

In [224]:
mask = mask.unsqueeze(0)
mask.shape

torch.Size([1, 100, 100])

In [225]:
mask = mask.expand(32,-1,-1)

In [226]:
mask.shape

torch.Size([32, 100, 100])

In [229]:
mask[0].shape

torch.Size([100, 100])

In [230]:
ls = nn.LogSoftmax(dim=1)

In [232]:
p1

tensor([[ 0.4137, -0.4797,  0.8668,  ...,  1.7687,  0.0290, -1.5110],
        [-1.3654, -0.1342,  1.8637,  ...,  1.4508,  1.2209, -1.5940],
        [ 0.3264,  1.9312,  1.4622,  ...,  2.5839,  0.6189, -1.5879],
        ...,
        [-1.2825,  0.6174, -0.7114,  ...,  0.2471,  0.6947, -0.6437],
        [-1.0260,  0.6079,  0.3940,  ...,  0.1906, -0.6009, -1.7763],
        [-0.7347,  1.3133,  0.2314,  ...,  0.3422,  0.3730,  1.1547]])

In [260]:
ls(p1)

tensor([[-4.7223, -5.6158, -4.2692,  ..., -3.3674, -5.1071, -6.6470],
        [-6.4513, -5.2202, -3.2223,  ..., -3.6351, -3.8651, -6.6799],
        [-4.7286, -3.1237, -3.5928,  ..., -2.4711, -4.4361, -6.6429],
        ...,
        [-6.3158, -4.4160, -5.7447,  ..., -4.7863, -4.3386, -5.6771],
        [-5.9430, -4.3091, -4.5230,  ..., -4.7263, -5.5179, -6.6933],
        [-5.7874, -3.7394, -4.8213,  ..., -4.7106, -4.6797, -3.8980]])

In [294]:
l1 = ls(p1).unsqueeze(2)
l1[0]

tensor([[-4.7223],
        [-5.6158],
        [-4.2692],
        [-6.0345],
        [-5.4311],
        [-7.2324],
        [-6.2206],
        [-4.5669],
        [-4.3774],
        [-3.2788],
        [-4.6844],
        [-5.6938],
        [-6.2352],
        [-4.5848],
        [-5.0405],
        [-6.6577],
        [-5.2806],
        [-5.0867],
        [-4.6085],
        [-5.9321],
        [-3.8752],
        [-5.7028],
        [-3.6106],
        [-4.8792],
        [-4.4604],
        [-4.1428],
        [-4.2659],
        [-5.4180],
        [-4.8070],
        [-4.3509],
        [-4.3480],
        [-3.6104],
        [-4.9513],
        [-5.4603],
        [-7.0213],
        [-3.8093],
        [-6.5016],
        [-6.0616],
        [-3.3723],
        [-5.3915],
        [-7.1205],
        [-4.6379],
        [-4.5171],
        [-5.8253],
        [-2.6023],
        [-5.2271],
        [-4.6759],
        [-4.7474],
        [-5.8019],
        [-5.3410],
        [-4.4021],
        [-3.1882],
        [-5.

In [290]:
l2 = ls(p2).unsqueeze(1)
l2[0]

tensor([[-5.8925, -5.8321, -5.6961, -4.1644, -4.0735, -4.3292, -6.7989, -5.9053,
         -4.4028, -5.0780, -2.4458, -6.1392, -6.3915, -6.0832, -5.2859, -5.1156,
         -5.8976, -5.4570, -4.6064, -4.3223, -4.2313, -3.5619, -3.5639, -4.9775,
         -5.5472, -6.9181, -4.2203, -3.5374, -5.7007, -5.3400, -3.4137, -4.0631,
         -5.1530, -4.7224, -3.1887, -4.2997, -6.6982, -4.8285, -4.8412, -4.1327,
         -6.8859, -4.8507, -3.6329, -3.7325, -5.1929, -5.6284, -3.8776, -6.4303,
         -4.9868, -4.9549, -5.5869, -6.9362, -6.2101, -4.9045, -4.6439, -5.1715,
         -6.2859, -5.4918, -3.2712, -4.2374, -4.4704, -4.9609, -4.7038, -6.3677,
         -3.2878, -5.8089, -5.3367, -5.0135, -4.2300, -4.8668, -5.5809, -7.1074,
         -5.3295, -4.5600, -6.8262, -5.3395, -4.9323, -4.1471, -3.8173, -5.0625,
         -5.3152, -5.3000, -3.9520, -6.1997, -5.9521, -6.6265, -5.9507, -4.8792,
         -5.0215, -4.5036, -5.6782, -4.1145, -6.9351, -6.6818, -4.3470, -7.1892,
         -3.9099, -6.9267, -

In [292]:
(l1 + l2)[0]

tensor([[-10.6149, -10.5544, -10.4184,  ..., -11.6490,  -9.8740, -10.5760],
        [-11.5083, -11.4478, -11.3118,  ..., -12.5424, -10.7674, -11.4694],
        [-10.1618, -10.1013,  -9.9653,  ..., -11.1959,  -9.4209, -10.1229],
        ...,
        [ -9.2599,  -9.1994,  -9.0634,  ..., -10.2940,  -8.5190,  -9.2210],
        [-10.9996, -10.9391, -10.8031,  ..., -12.0337, -10.2587, -10.9607],
        [-12.5395, -12.4791, -12.3431,  ..., -13.5737, -11.7987, -12.5007]])

In [265]:
score = ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)

In [240]:
score.shape

torch.Size([32, 100, 100])

In [266]:
score = score + mask

In [267]:
score

tensor([[[-10.6149, -10.5544, -10.4184,  ..., -11.6490,  -9.8740, -10.5760],
         [    -inf, -11.4478, -11.3118,  ..., -12.5424, -10.7674, -11.4694],
         [    -inf,     -inf,  -9.9653,  ..., -11.1959,  -9.4209, -10.1229],
         ...,
         [    -inf,     -inf,     -inf,  ..., -10.2940,  -8.5190,  -9.2210],
         [    -inf,     -inf,     -inf,  ...,     -inf, -10.2587, -10.9607],
         [    -inf,     -inf,     -inf,  ...,     -inf,     -inf, -12.5007]],

        [[-10.9911, -10.9528, -10.6235,  ..., -12.9481, -10.5561, -13.5723],
         [    -inf,  -9.7217,  -9.3923,  ..., -11.7170,  -9.3249, -12.3412],
         [    -inf,     -inf,  -7.3944,  ...,  -9.7191,  -7.3270, -10.3433],
         ...,
         [    -inf,     -inf,     -inf,  ..., -10.1319,  -7.7399, -10.7561],
         [    -inf,     -inf,     -inf,  ...,     -inf,  -7.9698, -10.9861],
         [    -inf,     -inf,     -inf,  ...,     -inf,     -inf, -13.8009]],

        [[-10.7144,  -9.1789,  -8.9693,  ...

In [269]:
score, s_idx = score.max(dim=1)

In [270]:
score

tensor([[-10.6149, -10.5544,  -9.9653,  ...,  -9.5290,  -7.7540,  -8.4560],
        [-10.9911,  -9.7217,  -7.3944,  ...,  -9.3226,  -6.9306,  -9.9468],
        [-10.7144,  -7.5740,  -7.3644,  ...,  -7.7463,  -6.3472,  -5.8191],
        ...,
        [-11.5930,  -8.5869, -10.3159,  ...,  -8.6843,  -8.4279,  -7.7343],
        [-11.5702, -10.0084,  -9.2411,  ...,  -6.4838,  -7.4405,  -7.7328],
        [-12.1434,  -8.8563,  -6.9355,  ...,  -7.7213,  -6.1229,  -6.3269]])

In [253]:
s_idx.shape

torch.Size([32, 100])

In [248]:
score, e_idx = score.max(dim=1)

In [252]:
score.shape

torch.Size([32])

In [254]:
e_idx.shape

torch.Size([32])

In [273]:
a = torch.randn(2,3,3)

In [274]:
a

tensor([[[-1.1602, -0.3537, -0.1346],
         [-0.0267,  0.9523,  0.2527],
         [-0.0346, -0.8539,  0.8453]],

        [[-1.3406, -1.2280, -0.0061],
         [-1.0539,  1.2168, -0.8714],
         [-1.4040, -0.5714, -0.9305]]])

In [276]:
a.max(dim=0)

torch.return_types.max(values=tensor([[-1.1602, -0.3537, -0.0061],
        [-0.0267,  1.2168,  0.2527],
        [-0.0346, -0.5714,  0.8453]]), indices=tensor([[0, 0, 1],
        [0, 1, 0],
        [0, 1, 0]]))

In [277]:
a.max(dim=1)

torch.return_types.max(values=tensor([[-0.0267,  0.9523,  0.8453],
        [-1.0539,  1.2168, -0.0061]]), indices=tensor([[1, 1, 2],
        [1, 1, 0]]))

In [278]:
a.max(dim=2)

torch.return_types.max(values=tensor([[-0.1346,  0.9523,  0.8453],
        [-0.0061,  1.2168, -0.5714]]), indices=tensor([[2, 1, 2],
        [2, 1, 1]]))

In [279]:
a , s = a.max(dim=1)

In [280]:
a

tensor([[-0.0267,  0.9523,  0.8453],
        [-1.0539,  1.2168, -0.0061]])

In [26]:
m = torch.tensor([[1],[2],[3],[4]])
n = torch.tensor([[1,2,3,4]])

In [27]:
m.shape

torch.Size([4, 1])

In [28]:
n.shape

torch.Size([1, 4])

In [30]:
sum = m + n

In [38]:
a = torch.randn(2,3)
a

tensor([[-2.1659, -1.7538, -1.8445],
        [-1.2636, -0.9122, -0.0057]])

In [41]:
a0, s = a.max(dim=0)
a0

tensor(-0.0057)

In [42]:
a1, e = a.max(dim=1)

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

In [40]:
a1

tensor([-1.7538, -0.0057])