In [12]:
import torch
from transformers import get_constant_schedule_with_warmup
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
import os
from dataloader import CoNLLReader
from tqdm import tqdm
from tutils import *

In [2]:
try:
    del globals()['NERmodelbase']
except:
    pass
from NERmodel import NERmodelbase

In [3]:
encoder_model = 'xlm-roberta-base'
tokenizer = AutoTokenizer.from_pretrained(encoder_model)

In [6]:
def collate_batch(batch):
    batch_ = list(zip(*batch))
    tokens, masks, token_masks, gold_spans, tags, lstm_encoded = batch_[0], batch_[1], batch_[2], batch_[3], \
                                                                 batch_[4], batch_[5]
    # print(tags)
    max_len = np.max([len(token) for token in tokens])
    # print(np.max([len(token) for token in tokens]), max_len)
    token_tensor = torch.empty(size=(len(tokens), max_len), dtype=torch.long).fill_(tokenizer.pad_token_id)
    tag_tensor = torch.empty(size=(len(tokens), max_len), dtype=torch.long).fill_(mconern['O'])
    mask_tensor = torch.zeros(size=(len(tokens), max_len), dtype=torch.bool)
    token_masks_tensor = torch.zeros(size=(len(tokens), max_len), dtype=torch.bool)
    lstm_encoded_tensor = torch.zeros(size=(len(tokens), max_len, 256), dtype=torch.float)
    # print(lstm_encoded.shape)
    for i in range(len(tokens)):
        tokens_ = tokens[i]
        seq_len = len(tokens_)

        token_tensor[i, :seq_len] = tokens_
        tag_tensor[i, :seq_len] = tags[i]
        mask_tensor[i, :seq_len] = masks[i]
        token_masks_tensor[i, :seq_len] = token_masks[i]
        lstm_encoded_tensor[i, 1:seq_len - 1, :] = lstm_encoded[i]

    return token_tensor, tag_tensor, mask_tensor, token_masks_tensor, gold_spans, lstm_encoded_tensor

In [8]:
def get_optimizer(net, opt=False):
    optimizer = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=0.03)
    if opt:
        scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEP)
        return [optimizer], [scheduler]
    return [optimizer]


In [9]:
NUM_EPOCH = 1
BATCH_SIZE = 64
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [16]:
# wnut_iob = {'B-CORP': 0, 'I-CORP': 1, 'B-CW': 2, 'I-CW': 3, 'B-GRP': 4, 'I-GRP': 5, 'B-LOC': 6, 'I-LOC': 7,
#             'B-PER': 8, 'I-PER': 9, 'B-PROD': 10, 'I-PROD': 11, 'O': 12}
mconern = indvidual(mconer_grouped, True)
reveremap = invert(mconer_grouped)

In [14]:
import pickle
if os.path.exists('train_load.pkl'):
    with open('train_load.pkl', 'rb') as f:
        ds = pickle.load(f)
else:
    print("reading from disk")
    ds = CoNLLReader(target_vocab=mconern, encoder_model=encoder_model, reversemap=reveremap, finegrained=fine)
    ds.read_data(data=r'C:\Users\Rah12937\PycharmProjects\mconer\multiconer2023\train_dev\en-train.conll')
    with open('train_load.pkl', 'wb') as f:
        pickle.dump(ds, f)

In [18]:
if os.path.exists('valid_load.pkl'):
    with open('valid_load.pkl', 'rb') as f:
        valid = pickle.load(f)
else:
    valid = CoNLLReader(target_vocab=mconern, encoder_model=encoder_model, reversemap=reveremap, finegrained=True)
    valid.read_data(data=r'C:\Users\Rah12937\PycharmProjects\mconer\multiconer2023\train_dev\en-dev.conll')
    with open('valid_load.pkl', 'wb') as f:
        pickle.dump(ds, f)

True
Reading file C:\Users\Rah12937\PycharmProjects\mconer\multiconer2023\train_dev\en-dev.conll


  x, y = self.lstm(torch.tensor(encoded))


Finished reading 871 instances from file C:\Users\Rah12937\PycharmProjects\mconer\multiconer2023\train_dev\en-dev.conll


In [20]:
model = NERmodelbase(tag_to_id=mconern, device=device, encoder_model=encoder_model, dropout=0.3, use_lstm=True).to(
    device)

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaModel: ['lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  nn.init.uniform(self.w_omega, -0.1, 0.1)


In [21]:
trainloader = DataLoader(ds, batch_size=BATCH_SIZE, collate_fn=collate_batch, num_workers=0, shuffle=False)
validloader = DataLoader(valid, batch_size=BATCH_SIZE, collate_fn=collate_batch, num_workers=0)

In [22]:
WARMUP_STEP = int(len(trainloader) * NUM_EPOCH * 0.1)
print(f"Number of warm up step is {WARMUP_STEP}")
optim, scheduler = get_optimizer(model, True, warmup=WARMUP_STEP)

Number of warm up step is 26


In [24]:
from pytorchtools import EarlyStopping
import random
from tensorboardX import SummaryWriter
run_id = random.randint(1, 10000)
run_name = f"runid_{run_id}_EP_{NUM_EPOCH}_fine_xlm-b-birnnn-focal-loss-0_8-sep_lr-alpha-2-gama-4"
writer = SummaryWriter(run_name)
step = 0
running_loss = 0
early_stopping = EarlyStopping(patience=10, verbose=True, path=run_name + '.pt')

In [27]:
import numpy as np
eval_step = 100
for epoch in range(NUM_EPOCH):
    val_track = []
    with tqdm(trainloader, unit='batch') as tepoch:
        # model.train()
        tepoch.set_description(f"Epoch {epoch}")
        for i, data in enumerate(tepoch):
            optim[0].zero_grad()
            outputs, focal_loss = model(data)
            loss = 0.2 * outputs['loss'] + 0.8 * focal_loss
            running_loss += loss
            loss.backward()
            optim[0].step()
            scheduler[0].step()
            # if i % 10 == 0:  # print every 2000 mini-batches
            model.spanf1.reset()
            # writer.add_scalar('lr', scheduler[0].get_last_lr()[0], step)
            # run validation
            step += 1
            if (step + 1) % eval_step == 0:
                model.eval()
                with torch.no_grad():
                    with tqdm(validloader, unit='batch') as tepoch:
                        val_loss = 0
                        for i, data in enumerate(tepoch):
                            outputs, focal_loss = model(data, mode='predict')
                            val_loss += 0.2 * outputs['loss'] + 0.8 * focal_loss
                model.train()
                writer.add_scalars("Loss",
                                   {
                                       "Train Loss": round(running_loss.detach().cpu().numpy().ravel()[0] / 20, 4),
                                       "Valid Loss": round(val_loss.detach().cpu().numpy().ravel()[0] / 20, 4),
                                   }
                                   , step)
                writer.add_scalars("Metrics", outputs['results'], step)
                # print(outputs['results'])
                # writer.add_scalar("Loss/Test", round(val_loss.numpy()[0] / len(validloader), 4), step)
                # writer.add_scalar("Loss/Valid",
                #                   round(val_loss.detach().cpu().numpy().ravel()[0] / 20, 4), step)
                val_track.append(round(val_loss.detach().cpu().numpy().ravel()[0] / eval_step, 4))
                running_loss = 0

                early_stopping(round(val_loss.detach().cpu().numpy().ravel()[0] / eval_step, 4), model)
                if early_stopping.early_stop:
                    print("Stopping early")
                    break

writer.close()

Epoch 0:  37%|███████████████████████████████████████████████████████████                                                                                                     | 97/263 [01:04<01:50,  1.51batch/s]
  0%|                                                                                                                                                                                   | 0/14 [00:00<?, ?batch/s][A
  7%|████████████▏                                                                                                                                                              | 1/14 [00:00<00:07,  1.75batch/s][A
 14%|████████████████████████▍                                                                                                                                                  | 2/14 [00:00<00:05,  2.30batch/s][A
 21%|████████████████████████████████████▋                                                                                                         

Validation loss decreased (inf --> 1.135500).  Saving model ...


Epoch 0:  75%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                        | 197/263 [02:19<00:46,  1.43batch/s]
  0%|                                                                                                                                                                                   | 0/14 [00:00<?, ?batch/s][A
  7%|████████████▏                                                                                                                                                              | 1/14 [00:00<00:06,  1.90batch/s][A
 14%|████████████████████████▍                                                                                                                                                  | 2/14 [00:00<00:05,  2.25batch/s][A
 21%|████████████████████████████████████▋                                                                                                         

Validation loss decreased (1.135500 --> 1.099600).  Saving model ...


Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 263/263 [03:14<00:00,  1.35batch/s]


In [28]:
# tokenizer.convert_ids_to_tokens(next(iter(trainloader))[0][0])

In [29]:
# next(iter(trainloader))

In [31]:
model.crf.transitions.cpu().detach().numpy().shape

(73, 73)

In [53]:
model.crf.viterbi_tags(torch.randn([2, 10, 73]))

[([22, 30, 2, 16, 12, 14, 64, 22, 44, 22], 19.867223739624023),
 ([50, 52, 8, 2, 24, 44, 45, 28, 64, 30], 22.552013397216797)]

In [248]:
z = torch.randn([1, 5, 73])

In [226]:
import allennlp.nn.util as util
def viterbi_tags(
    logits: torch.Tensor, mask: torch.BoolTensor = None, top_k: int = None,
    _constraint_mask= None, _transitions=None, start_transitions=None, end_transitions=None
):
    if mask is None:
        mask = torch.ones(*logits.shape[:2], dtype=torch.bool, device=logits.device)

    if top_k is None:
        top_k = 1
        flatten_output = True
    else:
        flatten_output = False

    _, max_seq_length, num_tags = logits.size()
#     print(num_tags)
    # Get the tensors out of the variables
    logits, mask = logits.data, mask.data

    # Augment transitions matrix with start and end transitions
    start_tag = num_tags
    end_tag = num_tags + 1
    transitions = torch.full((num_tags + 2, num_tags + 2), -10000.0, device=logits.device)

    # Apply transition constraints
    constrained_transitions = _transitions * _constraint_mask[:num_tags, :num_tags] + -10000.0 * (1 - _constraint_mask[:num_tags, :num_tags])
    transitions[:num_tags, :num_tags] = constrained_transitions.data

    if True:
        transitions[
            start_tag, :num_tags
        ] = start_transitions.detach() * _constraint_mask[
            start_tag, :num_tags
        ].data + -10000.0 * (
            1 - _constraint_mask[start_tag, :num_tags].detach()
        )
        transitions[:num_tags, end_tag] = end_transitions.detach() * _constraint_mask[
            :num_tags, end_tag
        ].data + -10000.0 * (1 - _constraint_mask[:num_tags, end_tag].detach())
    else:
        transitions[start_tag, :num_tags] = -10000.0 * (
            1 - self._constraint_mask[start_tag, :num_tags].detach()
        )
        transitions[:num_tags, end_tag] = -10000.0 * (
            1 - self._constraint_mask[:num_tags, end_tag].detach()
        )

    best_paths = []
    # Pad the max sequence length by 2 to account for start_tag + end_tag.
    tag_sequence = torch.empty(max_seq_length + 2, num_tags + 2, device=logits.device)

    for prediction, prediction_mask in zip(logits, mask):
        mask_indices = prediction_mask.nonzero(as_tuple=False).squeeze()
        masked_prediction = torch.index_select(prediction, 0, mask_indices)
        sequence_length = masked_prediction.shape[0]

        # Start with everything totally unlikely
        tag_sequence.fill_(-10000.0)
        # At timestep 0 we must have the START_TAG
        tag_sequence[0, start_tag] = 0.0
        # At steps 1, ..., sequence_length we just use the incoming prediction
        tag_sequence[1 : (sequence_length + 1), :num_tags] = masked_prediction
        # And at the last timestep we must have the END_TAG
        tag_sequence[sequence_length + 1, end_tag] = 0.0
#         print(torch.argmax(tag_sequence, dim=-1))

        # We pass the tags and the transitions to `viterbi_decode`.
        viterbi_paths, viterbi_scores = viterbi_decode(
            tag_sequence=tag_sequence[: (sequence_length + 2)],
            transition_matrix=transitions,
            top_k=top_k,
        )
#         print(viterbi_paths, viterbi_scores)
        top_k_paths = []
        for viterbi_path, viterbi_score in zip(viterbi_paths, viterbi_scores):
            # Get rid of START and END sentinels and append.
            viterbi_path = viterbi_path[1:-1]
            top_k_paths.append((viterbi_path, viterbi_score.item()))
        best_paths.append(top_k_paths)

    if flatten_output:
        return [top_k_paths[0] for top_k_paths in best_paths]

    return best_paths


In [253]:
from typing import Optional, List
def viterbi_decode(
    tag_sequence: torch.Tensor,
    transition_matrix: torch.Tensor,
    tag_observations: Optional[List[int]] = None,
    allowed_start_transitions: torch.Tensor = None,
    allowed_end_transitions: torch.Tensor = None,
    top_k: int = None,
):
    print(tag_sequence.shape)
    if top_k is None:
        top_k = 1
        flatten_output = True
    elif top_k >= 1:
        flatten_output = False
    else:
        raise ValueError(f"top_k must be either None or an integer >=1. Instead received {top_k}")

    sequence_length, num_tags = list(tag_sequence.size())

    has_start_end_restrictions = (
        allowed_end_transitions is not None or allowed_start_transitions is not None
    )

    if has_start_end_restrictions:

        if allowed_end_transitions is None:
            allowed_end_transitions = torch.zeros(num_tags)
        if allowed_start_transitions is None:
            allowed_start_transitions = torch.zeros(num_tags)

        num_tags = num_tags + 2
        new_transition_matrix = torch.zeros(num_tags, num_tags)
        new_transition_matrix[:-2, :-2] = transition_matrix

        # Start and end transitions are fully defined, but cannot transition between each other.

        allowed_start_transitions = torch.cat(
            [allowed_start_transitions, torch.tensor([-math.inf, -math.inf])]
        )
        allowed_end_transitions = torch.cat(
            [allowed_end_transitions, torch.tensor([-math.inf, -math.inf])]
        )

        # First define how we may transition FROM the start and end tags.
        new_transition_matrix[-2, :] = allowed_start_transitions
        # We cannot transition from the end tag to any tag.
        new_transition_matrix[-1, :] = -math.inf

        new_transition_matrix[:, -1] = allowed_end_transitions
        # We cannot transition to the start tag from any tag.
        new_transition_matrix[:, -2] = -math.inf

        transition_matrix = new_transition_matrix

    if tag_observations:
        if len(tag_observations) != sequence_length:
            raise ConfigurationError(
                "Observations were provided, but they were not the same length "
                "as the sequence. Found sequence of length: {} and evidence: {}".format(
                    sequence_length, tag_observations
                )
            )
    else:
        tag_observations = [-1 for _ in range(sequence_length)]

    if has_start_end_restrictions:
        tag_observations = [num_tags - 2] + tag_observations + [num_tags - 1]
        zero_sentinel = torch.zeros(1, num_tags)
        extra_tags_sentinel = torch.ones(sequence_length, 2) * -math.inf
        tag_sequence = torch.cat([tag_sequence, extra_tags_sentinel], -1)
        tag_sequence = torch.cat([zero_sentinel, tag_sequence, zero_sentinel], 0)
        sequence_length = tag_sequence.size(0)

    path_scores = []
    path_indices = []

    if tag_observations[0] != -1:
        one_hot = torch.zeros(num_tags)
        one_hot[tag_observations[0]] = 100000.0
        path_scores.append(one_hot.unsqueeze(0))
    else:
        path_scores.append(tag_sequence[0, :].unsqueeze(0))

    # Evaluate the scores for all possible paths.
#     print("length", path_scores[0].shape)
    all_score = []
    for timestep in range(1, sequence_length):
#         print(timestep, path_scores[timestep - 1].shape, transition_matrix.shape)
        # Add pairwise potentials to current scores.
        summed_potentials = path_scores[timestep - 1].unsqueeze(2) + transition_matrix
        summed_potentials = summed_potentials.view(-1, num_tags)
#         print(summed_potentials.shape, num_tags)
        # Best pairwise potential path score from the previous timestep.
        max_k = min(summed_potentials.size()[0], top_k)
        scores, paths = torch.topk(summed_potentials, k=max_k, dim=0)
        all_score.append(scores)
#         print("best paths", paths)
        # If we have an observation for this timestep, use it
        # instead of the distribution over tags.
        observation = tag_observations[timestep]
        # Warn the user if they have passed
        # invalid/extremely unlikely evidence.
        if tag_observations[timestep - 1] != -1 and observation != -1:
            if transition_matrix[tag_observations[timestep - 1], observation] < -10000:
                logger.warning(
                    "The pairwise potential between tags you have passed as "
                    "observations is extremely unlikely. Double check your evidence "
                    "or transition potentials!"
                )
        if observation != -1:
            one_hot = torch.zeros(num_tags)
            one_hot[observation] = 100000.0
            path_scores.append(one_hot.unsqueeze(0))
        else:
            path_scores.append(tag_sequence[timestep, :] + scores)
        path_indices.append(paths.squeeze())
#     print("Simple Argmax", torch.argmax(torch.cat(path_scores[1:-1], dim=0), dim=-1))
    print(torch.argmax(torch.cat(all_score, dim=0), dim=-1))
    # Construct the most likely sequence backwards.
    path_scores_v = path_scores[-1].view(-1)
#     print(path_scores[-1].shape)
    max_k = min(path_scores_v.size()[0], top_k)
    viterbi_scores, best_paths = torch.topk(path_scores_v, k=max_k, dim=0)
#     print(viterbi_scores, best_paths, max_k)
    viterbi_paths = []
    for i in range(max_k):
        viterbi_path = [best_paths[i]]
        for backward_timestep in reversed(path_indices):
            print(backward_timestep, viterbi_path[-1])
            viterbi_path.append(int(backward_timestep.view(-1)[viterbi_path[-1]]))
        # Reverse the backward path.
        viterbi_path.reverse()

        if has_start_end_restrictions:
            viterbi_path = viterbi_path[1:-1]

        # Viterbi paths uses (num_tags * n_permutations) nodes; therefore, we need to modulo.
        viterbi_path = [j % num_tags for j in viterbi_path]
        viterbi_paths.append(viterbi_path)

    if flatten_output:
        return viterbi_paths[0], viterbi_scores[0]

    return viterbi_paths, viterbi_scores


In [254]:
viterbi_tags(logits=z, top_k=1,
             _constraint_mask= model.crf._constraint_mask, _transitions=model.crf.transitions,
             start_transitions=model.crf.start_transitions,
            end_transitions = model.crf.end_transitions)

torch.Size([7, 75])
tensor([52, 28, 74, 74, 74, 74])
tensor([30,  0, 30,  2, 30,  5, 30,  6, 30,  8, 30, 10, 30, 13, 30, 14, 30, 17,
        30, 18, 30, 21, 30, 22, 30, 24, 30, 26, 30, 28, 30, 30, 30, 32, 30, 34,
        30, 36, 30, 38, 30, 40, 30, 42, 30, 44, 30, 46, 30, 48, 30, 50, 30, 52,
        30, 54, 30, 57, 30, 58, 30, 60, 30, 62, 30, 64, 30, 66, 30, 68, 30, 71,
        30, 30, 57]) tensor(74)
tensor([71,  0, 71,  3, 10,  4, 71,  6, 10,  8, 10, 10, 10, 12, 71, 14, 10, 16,
        10, 19, 71, 20, 71, 22, 10, 24, 10, 26, 10, 28, 71, 30, 10, 32, 71, 34,
        10, 37, 71, 39, 71, 40, 10, 42, 10, 45, 10, 46, 10, 48, 71, 51, 10, 52,
        10, 54, 71, 56, 71, 58, 71, 60, 10, 62, 10, 64, 10, 66, 71, 68, 71, 71,
        71, 10, 45]) 57
tensor([50,  0, 50,  2, 70,  4, 70,  6, 70,  8, 70, 10, 70, 12, 50, 15, 50, 16,
        70, 18, 50, 20, 50, 22, 70, 24, 70, 26, 70, 28, 70, 30, 70, 32, 50, 34,
        50, 37, 50, 38, 70, 40, 50, 43, 50, 44, 70, 47, 50, 48, 50, 50, 70, 53,
        50,

[[([38, 52, 70, 56, 57], 11.473505973815918)]]

In [250]:
p = torch.randn([1,10]).unsqueeze(2)
q = torch.randn([10,10])
# print(p,q)

In [265]:
def _viterbi_decoding(emissions, transitions):
    scores = torch.zeros(emissions.size(1))
#     print(scores)
#     back_pointers = torch.zeros(emissions.size()).int()
#     scores = scores + emissions[0]
#     # Generate most likely scores and paths for each step in sequence
#     for i in range(1, emissions.size(0)):
#         scores_with_transitions = scores.unsqueeze(1).expand_as(transitions) + transitions
#         max_scores, back_pointers[i] = torch.max(scores_with_transitions, 0)
#         scores = emissions[i] + max_scores
#     # Generate the most likely path
#     viterbi = [scores.numpy().argmax()]
#     back_pointers = back_pointers.numpy()
#     for bp in reversed(back_pointers[1:]):
#         viterbi.append(bp[viterbi[-1]])
#     viterbi.reverse()
#     viterbi_score = scores.numpy().max()
#     return viterbi_score, viterbi

In [266]:
_viterbi_decoding(z, model.crf.transitions)

tensor([0., 0., 0., 0., 0.])


tensor([[-1.1792e+00, -2.7135e-02, -7.7828e-01,  1.0579e+00, -1.4069e+00,
          2.9101e-01,  1.3014e+00, -9.4263e-01, -7.3761e-01, -3.2752e-01,
          1.3143e+00, -9.0450e-01,  2.5484e-01, -1.7810e+00, -3.5424e-01,
          8.2577e-01,  1.2528e+00,  2.9811e-01, -8.1395e-01, -1.8555e+00,
         -8.4935e-01, -1.9918e+00,  6.0543e-02, -3.1424e-01, -3.9017e-01,
          2.4125e-01,  9.8717e-01,  1.6513e-01, -6.9897e-01, -1.1438e+00,
          5.5704e-01,  7.1476e-01,  1.2251e+00,  9.0660e-01,  4.3582e-02,
         -1.4489e+00, -1.5145e+00,  7.7617e-01,  2.4574e+00,  6.4148e-01,
          8.6552e-02, -6.3901e-01,  1.8131e+00, -1.7130e+00,  5.1753e-01,
          3.5789e-01, -1.4740e-01,  5.9639e-01,  9.0849e-01, -1.5356e+00,
         -1.1607e+00,  4.4085e-01, -1.9936e+00, -9.8936e-01, -3.0054e-01,
         -1.1908e+00, -4.1412e-01,  8.7369e-01,  6.1928e-01,  3.9566e-01,
         -1.5622e+00,  6.1427e-01, -3.1695e-01, -2.2818e+00,  8.9419e-01,
          4.0324e-01, -1.7241e+00, -6.

In [175]:
torch.topk((p+q).view(-1,10), k=1, dim=0)

torch.return_types.topk(
values=tensor([[3.1213, 0.9428, 0.6200, 1.2594, 1.3068, 1.4101, 0.7416, 1.1139, 1.8678,
         1.8825]]),
indices=tensor([[6, 3, 5, 5, 0, 6, 3, 6, 6, 5]]))

In [87]:
model.crf._constraint_mask[:73, :73] *  model.crf.transitions

tensor([[-0.1321,  0.1094, -0.1860,  ..., -0.2469,  0.0000, -0.0261],
        [ 0.1491,  0.2218,  0.0244,  ..., -0.0113,  0.0000, -0.0211],
        [ 0.0466,  0.0000,  0.0872,  ..., -0.2424, -0.0000, -0.3008],
        ...,
        [-0.1596, -0.0000, -0.1920,  ..., -0.2957,  0.2783, -0.0859],
        [ 0.0113,  0.0000,  0.0119,  ..., -0.1425,  0.0664, -0.1141],
        [-0.0490, -0.0000, -0.0082,  ..., -0.1509, -0.0000,  0.1871]],
       device='cuda:0', grad_fn=<MulBackward0>)