# Birthdays probing test

final results, comparing my trained model vs. distilbert pretrained:

```
                   month    day
my model:             10    3.5 
distilbert:           25    8.0
random guessing:       8    3.2
```

In [1]:
import sys
sys.path.append('/home/jxm3/research/deidentification/unsupervised-deidentification')

In [None]:
from model import DocumentProfileMatchingTransformer

import os

num_cpus = os.cpu_count()

model = DocumentProfileMatchingTransformer.load_from_checkpoint(
    # distilbert-distilbert model
    #    '/home/jxm3/research/deidentification/unsupervised-deidentification/saves/distilbert-base-uncased__dropout_0.8_0.8/deid-wikibio_default/1irhznnp_130/checkpoints/epoch=25-step=118376.ckpt',
    # roberta-distilbert model
    # '/home/jxm3/research/deidentification/unsupervised-deidentification/saves/roberta__distilbert-base-uncased__dropout_0.8_0.8/deid-wikibio_default/1f7mlhxn_162/checkpoints/epoch=16-step=309551.ckpt',
    # roberta-distilbert model trained for longer
    '/home/jxm3/research/deidentification/unsupervised-deidentification/saves/roberta__distilbert-base-uncased__dropout_0.8_0.8/deid-wikibio_default/3nbt75gp_171/checkpoints/epoch=20-step=382387.ckpt',
    document_model_name_or_path='roberta-base',
    profile_model_name_or_path='distilbert-base-uncased',
    num_workers=min(8, num_cpus),
    train_batch_size=64,
    eval_batch_size=64,
    learning_rate=1e-6,
    max_seq_length=256,
    pretrained_profile_encoder=False,
    word_dropout_ratio=0.0,
    word_dropout_perc=0.0,
    lr_scheduler_factor=0.5,
    lr_scheduler_patience=3,
    adversarial_mask_k_tokens=0,
    train_without_names=False,
)

In [78]:
from datamodule import WikipediaDataModule
import os

num_cpus = os.cpu_count()

dm = WikipediaDataModule(
    mask_token=model.document_tokenizer.mask_token,
    dataset_name='wiki_bio',
    dataset_train_split='train[:100%]',
    dataset_val_split='val[:20%]',
    dataset_version='1.2.0',
    num_workers=min(8, num_cpus),
    train_batch_size=64,
    eval_batch_size=64,
)
dm.setup("fit")

Initializing WikipediaDataModule with num_workers = 8 and mask token `<mask>`
loading wiki_bio[1.2.0] split train[:100%]


Using custom data configuration default
Reusing dataset wiki_bio (/home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da)


loading wiki_bio[1.2.0] split val[:20%]


Using custom data configuration default
Reusing dataset wiki_bio (/home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da)
Loading cached processed dataset at /home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da/cache-58e5e96e220311ed.arrow
Loading cached processed dataset at /home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da/cache-778e9a6d1b0dfab7.arrow
Loading cached processed dataset at /home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da/cache-3c4e94260fbd4dd3.arrow
Loading cached processed dataset at /home/jxm3/.cache/huggingface/datasets/wiki_bio/default/1.2.0/c05ce066e9026831cd7535968a311fc80f074b58868cfdffccbc811dff2ab6da/cache-9e279afc7bfb46f2.arrow
Loading cached processed dataset at /h

## Get the birthday data

In [79]:
import datetime

d = datetime.datetime.strptime('17 january 1943', "%d %B %Y")
d.day

17

In [160]:
from typing import List, Tuple

from tqdm.notebook import tqdm

import datetime
import re


def process_dataset(_dataset) -> List[Tuple[int, int]]:
    _processed_data = []
    for idx, d in enumerate(tqdm(_dataset, 'processing birthdays')):
        profile = d['profile']
        date_str_matches = re.search(r"birth_date \| ([\d]{1,4} [a-z]+ [\d]{1,4})", profile)
        if date_str_matches:
            date_str = date_str_matches.group(1)
            # print(date_str)
            # parse to datetime.datetime
            try:
                dt = datetime.datetime.strptime(date_str, "%d %B %Y")
            except ValueError as e:
                # print(e)
                continue
            # day_class_num = (dt.month - 1) * 31 + (dt.day - 1)
            # _processed_data.append((idx, day_class_num))
            _processed_data.append((idx, dt.month-1, dt.day-1))
    return _processed_data

## Create birthday data module

In [161]:
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

num_cpus = os.cpu_count()

class BirthdayDataModule(LightningDataModule):
    train_dataset: List[Tuple[int, int, int]]
    val_dataset: List[Tuple[int, int, int]]
    batch_size: int
    def __init__(self, dm: WikipediaDataModule, batch_size: int = 64):
        super().__init__()
        self.train_dataset = process_dataset(dm.train_dataset)
        self.val_dataset = process_dataset(dm.val_dataset)
        self.batch_size = batch_size
        self.num_workers = min(4, num_cpus)

    def setup(self, stage: str) -> None:
        return

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False # Only shuffle for train
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False
        )


In [162]:
birthday_dm = BirthdayDataModule(dm)

processing birthdays:   0%|          | 0/582659 [00:00<?, ?it/s]

processing birthdays:   0%|          | 0/14566 [00:00<?, ?it/s]

## Create birthday model

In [93]:
import numpy as np

def precompute_embeddings(model: DocumentProfileMatchingTransformer, datamodule: WikipediaDataModule):
    model.profile_model.cuda()
    model.profile_model.eval()
    print('Precomputing profile embeddings before first epoch...')
    
    model.train_profile_embeddings = np.zeros((len(datamodule.train_dataset), model.profile_embedding_dim))
    for train_batch in tqdm(datamodule.train_dataloader(), desc="[1/2] Precomputing train embeddings - profile", colour="cyan", leave=False):
        with torch.no_grad():
            profile_embeddings = model.forward_profile_text(text=train_batch["profile"])
        model.train_profile_embeddings[train_batch["text_key_id"]] = profile_embeddings.cpu()
    model.train_profile_embeddings = torch.tensor(model.train_profile_embeddings, dtype=torch.float32)
    
    model.val_profile_embeddings = np.zeros((len(datamodule.val_dataset), model.profile_embedding_dim))
    for val_batch in tqdm(datamodule.val_dataloader(), desc="[2/2] Precomputing val embeddings - profile", colour="green", leave=False):
        with torch.no_grad():
            profile_embeddings = model.forward_profile_text(text=val_batch["profile"])
        model.val_profile_embeddings[val_batch["text_key_id"]] = profile_embeddings.cpu()
    model.val_profile_embeddings = torch.tensor(model.val_profile_embeddings, dtype=torch.float32)
    
    
    model.profile_model.train()

precompute_embeddings(model, dm)

Precomputing profile embeddings before first epoch...


[1/2] Precomputing train embeddings - profile:   0%|          | 0/9105 [00:00<?, ?it/s]

[2/2] Precomputing val embeddings - profile:   0%|          | 0/228 [00:00<?, ?it/s]

In [189]:
from typing import Dict

import torch
import torchmetrics
import transformers

from pytorch_lightning import LightningModule
from transformers import AdamW

class BirthdayModel(LightningModule):
    """Probes the PROFILE for birthday info."""
    profile_embeddings: torch.Tensor
    classifier: torch.nn.Module
    learning_rate: float
    
    def __init__(self, model: DocumentProfileMatchingTransformer, learning_rate: float):
        super().__init__()
        # We can pre-calculate these embeddings bc
        self.train_profile_embeddings = torch.tensor(model.train_profile_embeddings.cpu())
        self.val_profile_embeddings = torch.tensor(model.val_profile_embeddings.cpu())
        self.month_classifier = torch.nn.Sequential(
            torch.nn.Linear(model.profile_embedding_dim, 64),
            # torch.nn.Dropout(p=0.01),
            torch.nn.Linear(64, 12),
        )
        self.day_classifier = torch.nn.Sequential(
            torch.nn.Linear(model.profile_embedding_dim, 64),
            # torch.nn.Dropout(p=0.01),
            torch.nn.Linear(64, 31),
        )
        self.learning_rate = learning_rate
        self.train_accuracy = torchmetrics.Accuracy()
        self.val_accuracy   = torchmetrics.Accuracy()
        self.loss_criterion = torch.nn.CrossEntropyLoss()

    def training_step(self, batch: Tuple[int, int], batch_idx: int) -> torch.Tensor:
        profile_idxs, months, days = batch
        assert ((0 <= profile_idxs) & (profile_idxs < len(self.train_profile_embeddings))).all()
        assert ((0 <= months) & (months < 12)).all()
        assert ((0 <= days) & (days < 31)).all()
        
        clf_device = next(self.month_classifier.parameters()).device
        with torch.no_grad():
            embedding = self.train_profile_embeddings[profile_idxs].to(clf_device)
        
        
        month_logits = self.month_classifier(embedding)
        day_logits = self.day_classifier(embedding)
        
        
        month_loss = torch.nn.functional.cross_entropy(month_logits, months)
        day_loss = torch.nn.functional.cross_entropy(day_logits, days)
        
        self.log('train_acc_month', self.train_accuracy(month_logits, months))
        self.log('train_acc_day', self.train_accuracy(day_logits, days))
        
        if batch_idx == 0:
            print('train_acc_month', self.train_accuracy(month_logits, months))
            print('train_acc_day', self.train_accuracy(day_logits, days))
        
        return (month_loss + day_loss)
    
    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int):
        profile_idxs, months, days = batch
        assert ((0 <= profile_idxs) & (profile_idxs < len(self.val_profile_embeddings))).all()
        assert ((0 <= months) & (months < 12)).all()
        assert ((0 <= days) & (days < 31)).all()
        
        clf_device = next(self.month_classifier.parameters()).device
        with torch.no_grad():
            embedding = self.val_profile_embeddings[profile_idxs].to(clf_device)
        
        
        month_logits = self.month_classifier(embedding)
        day_logits = self.day_classifier(embedding)
        
        
        month_loss = torch.nn.functional.cross_entropy(month_logits, months)
        day_loss = torch.nn.functional.cross_entropy(day_logits, days)
        
        self.log('val_acc_month', self.val_accuracy(month_logits, months))
        self.log('val_acc_day', self.val_accuracy(day_logits, days))
        
        if batch_idx == 0:
            print('val_acc_month', self.val_accuracy(month_logits, months))
            print('val_acc_day', self.val_accuracy(day_logits, days))

        return (month_loss + day_loss)

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        optimizer = AdamW(
            list(self.parameters()), lr=self.learning_rate
        )
        return optimizer
            

## Train it

In [167]:
from pytorch_lightning import Trainer, seed_everything

seed_everything(42)

num_validations_per_epoch = 4

Global seed set to 42


In [190]:
birthday_model = BirthdayModel(model, 1e-3)
birthday_dm.batch_size = 512

# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
trainer = Trainer(
    default_root_dir=f"saves/jup/birthday_probing",
    val_check_interval=1.0,
    max_epochs=25,
    log_every_n_steps=50,
    gpus=torch.cuda.device_count(),
)

  self.train_profile_embeddings = torch.tensor(model.train_profile_embeddings.cpu())
  self.val_profile_embeddings = torch.tensor(model.val_profile_embeddings.cpu())
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [191]:
trainer.fit(birthday_model, birthday_dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Set SLURM handle signals.

  | Name             | Type             | Params
------------------------------------------------------
0 | month_classifier | Sequential       | 50.0 K
1 | day_classifier   | Sequential       | 51.2 K
2 | train_accuracy   | Accuracy         | 0     
3 | val_accuracy     | Accuracy         | 0     
4 | loss_criterion   | CrossEntropyLoss | 0     
------------------------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.405     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


val_acc_month tensor(0.0938, device='cuda:0')
val_acc_day tensor(0.0273, device='cuda:0')


Training: 0it [00:00, ?it/s]

train_acc_month tensor(0.0977, device='cuda:0')
train_acc_day tensor(0.0371, device='cuda:0')


Validating: 0it [00:00, ?it/s]

val_acc_month tensor(0.0957, device='cuda:0')
val_acc_day tensor(0.0312, device='cuda:0')


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f2f78ec29d0>
Traceback (most recent call last):
  File "/home/jxm3/.conda/envs/textattack/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/jxm3/.conda/envs/textattack/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/jxm3/.conda/envs/textattack/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    

train_acc_month tensor(0.0762, device='cuda:0')
train_acc_day tensor(0.0430, device='cuda:0')


assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f2f78ec29d0>
Traceback (most recent call last):
  File "/home/jxm3/.conda/envs/textattack/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/home/jxm3/.conda/envs/textattack/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers
    if w.is_alive():
  File "/home/jxm3/.conda/envs/textattack/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f2f78ec29d0>
Traceback (most recent call last):
  File "/home/jxm3/.conda/envs/textattack/lib/python3.9/site-packages/torch/utils/data/datalo

Validating: 0it [00:00, ?it/s]

val_acc_month tensor(0.0996, device='cuda:0')
val_acc_day tensor(0.0293, device='cuda:0')
train_acc_month tensor(0.0781, device='cuda:0')
train_acc_day tensor(0.0391, device='cuda:0')


Validating: 0it [00:00, ?it/s]

val_acc_month tensor(0.0996, device='cuda:0')
val_acc_day tensor(0.0234, device='cuda:0')
train_acc_month tensor(0.0723, device='cuda:0')
train_acc_day tensor(0.0430, device='cuda:0')


Validating: 0it [00:00, ?it/s]

val_acc_month tensor(0.1016, device='cuda:0')
val_acc_day tensor(0.0312, device='cuda:0')
train_acc_month tensor(0.0801, device='cuda:0')
train_acc_day tensor(0.0391, device='cuda:0')


Validating: 0it [00:00, ?it/s]

val_acc_month tensor(0.0977, device='cuda:0')
val_acc_day tensor(0.0195, device='cuda:0')
train_acc_month tensor(0.0801, device='cuda:0')
train_acc_day tensor(0.0352, device='cuda:0')


Validating: 0it [00:00, ?it/s]

val_acc_month tensor(0.0977, device='cuda:0')
val_acc_day tensor(0.0273, device='cuda:0')
train_acc_month tensor(0.0801, device='cuda:0')
train_acc_day tensor(0.0391, device='cuda:0')


Validating: 0it [00:00, ?it/s]

val_acc_month tensor(0.0996, device='cuda:0')
val_acc_day tensor(0.0312, device='cuda:0')
train_acc_month tensor(0.0742, device='cuda:0')
train_acc_day tensor(0.0391, device='cuda:0')


Validating: 0it [00:00, ?it/s]

val_acc_month tensor(0.0996, device='cuda:0')
val_acc_day tensor(0.0352, device='cuda:0')
train_acc_month tensor(0.0781, device='cuda:0')
train_acc_day tensor(0.0352, device='cuda:0')


Validating: 0it [00:00, ?it/s]

val_acc_month tensor(0.1016, device='cuda:0')
val_acc_day tensor(0.0332, device='cuda:0')
train_acc_month tensor(0.0781, device='cuda:0')
train_acc_day tensor(0.0391, device='cuda:0')


Validating: 0it [00:00, ?it/s]

val_acc_month tensor(0.1016, device='cuda:0')
val_acc_day tensor(0.0273, device='cuda:0')
train_acc_month tensor(0.0820, device='cuda:0')
train_acc_day tensor(0.0391, device='cuda:0')


Validating: 0it [00:00, ?it/s]

val_acc_month tensor(0.0957, device='cuda:0')
val_acc_day tensor(0.0273, device='cuda:0')
train_acc_month tensor(0.0801, device='cuda:0')
train_acc_day tensor(0.0391, device='cuda:0')


Validating: 0it [00:00, ?it/s]

val_acc_month tensor(0.0957, device='cuda:0')
val_acc_day tensor(0.0293, device='cuda:0')
train_acc_month tensor(0.0762, device='cuda:0')
train_acc_day tensor(0.0410, device='cuda:0')


Validating: 0it [00:00, ?it/s]

val_acc_month tensor(0.0918, device='cuda:0')
val_acc_day tensor(0.0293, device='cuda:0')
train_acc_month tensor(0.0703, device='cuda:0')
train_acc_day tensor(0.0410, device='cuda:0')


Validating: 0it [00:00, ?it/s]

val_acc_month tensor(0.0898, device='cuda:0')
val_acc_day tensor(0.0352, device='cuda:0')
train_acc_month tensor(0.0684, device='cuda:0')
train_acc_day tensor(0.0430, device='cuda:0')


Validating: 0it [00:00, ?it/s]

val_acc_month tensor(0.0918, device='cuda:0')
val_acc_day tensor(0.0332, device='cuda:0')
train_acc_month tensor(0.0684, device='cuda:0')
train_acc_day tensor(0.0449, device='cuda:0')


Validating: 0it [00:00, ?it/s]

val_acc_month tensor(0.0898, device='cuda:0')
val_acc_day tensor(0.0352, device='cuda:0')
train_acc_month tensor(0.0703, device='cuda:0')
train_acc_day tensor(0.0449, device='cuda:0')


In [None]:
import torchmetrics

In [68]:
val_batch = next(iter(birthday_dm.val_dataloader()))

def do_validation_batch(batch, batch_idx):
    profile_idxs, birthday_idxs = batch
    clf_device = next(birthday_model.classifier.parameters()).device
    embedding = birthday_model.val_profile_embeddings[profile_idxs].to(clf_device)
    print('emebdding.shape:', embedding.shape)
    birthday_logits = birthday_model.classifier(embedding)
    print('birthday_logits.shape:', birthday_logits.shape)
    loss = torch.nn.functional.cross_entropy(
        birthday_logits, birthday_idxs
    )
    # self.log('val_accuracy', self.val_accuracy(birthday_logits, birthday_idxs))
    print('loss:', loss)

do_validation_batch(val_batch, 0)

emebdding.shape: torch.Size([64, 768])
birthday_logits.shape: torch.Size([64, 372])
loss: tensor(6.0617, grad_fn=<NllLossBackward0>)


In [125]:
val_batch # last element: idx 85, birthday 55

[tensor([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 18, 19, 20, 21, 23,
         26, 28, 30, 31, 32, 35, 36, 37, 38, 40, 41, 42, 43, 44, 45, 47, 48, 49,
         50, 51, 52, 53, 54, 55, 56, 60, 61, 62, 63, 67, 68, 69, 70, 71, 72, 73,
         74, 76, 77, 78, 79, 80, 81, 82, 83, 85]),
 tensor([339,  75,  93, 131, 221, 331, 329, 334, 106, 241, 102, 219, 282, 129,
         206, 211,  14, 117, 170, 151, 101, 222, 232, 312, 347, 254,  36, 361,
          47, 207, 250, 212,  85, 272, 266, 204,  94,   4, 148, 141, 267, 325,
          87, 228, 371,  74, 285, 193,  48, 209, 126,  16,  21, 365, 183,  25,
         317, 247,  66,  74,  39,  58, 251,  55])]

In [126]:
55 % 31 # february 24th

dm.val_dataset[85]

{'document': "ben wilson (born 25 february 1977) is a former australian rules footballer who played with collingwood and the sydney swans in the australian football league (afl) .\nwilson was secured by collingwood from norwood in the 1994 afl draft with the ninth selection , but first not from a tac cup side .\nthe south australian did n't feature in the 1995 afl season and then appeared twice for collingwood in 1996 .\nhe was traded to sydney at the end of 1996 , along with mark orchard and two draft picks , for which collingwood received anthony rocca .\nhe played in the opening three rounds of the 1997 season but made only one further appearance .\n",
 'profile': "fullname | ben wilson\nname | ben wilson\noriginalteam | norwood\nyears | 1996 1997 '' ` total - '' '\ndraftpick | 9th , 1994 afl draft\nclubs | collingwood sydney swans\nbirth_date | 25 february 1977\narticle_title | ben wilson -lrb- australian footballer -rrb-\nheightweight | 191 ; kg & nbsp ; cm / 87 & nbsp\nstatsend |

In [137]:
birthday_model.val_profile_embeddings[85][:5]

tensor([ 0.0936,  0.3824,  0.6874,  0.7990, -0.5991])

In [138]:
model.eval()
model.forward_profile_text(text=[dm.val_dataset[85]['profile']])[0, :5]

tensor([ 0.0936,  0.3824,  0.6874,  0.7990, -0.5991], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [129]:
train_batch = next(iter(birthday_dm.train_dataloader()))
train_batch # 682, 78
78 % 31 # 16 -> this is march 17th

16

In [131]:
dm.train_dataset[682]['profile']

'nationalgoals | 12\nfullname | jesús candelas rodrigo\nmanagerclubs | netherlands assistant -rrb- iran netherlands malta thailand -lrb- assistant -rrb- hong kong malaysia netherlands -lrb-\nname | victor hermans\narticle_title | victor hermans\nnationalyears | 1977 -- 1989\nposition | manager -lrb- association football -rrb-\ncurrentclub | thailand national futsal team -lrb- head coach -rrb-\nclubs | mvv maastricht k.s.k. tongeren\nnationalteam | netherlands -lrb- futsal -rrb-\nbirth_place | maastricht , netherlands\nbirth_date | 17 march 1953\nnationalcaps | 50\nmanageryears | 1990 2000 2001 2001-2007 2009 -- 2011 2012 -- -- 1992 1992 -- 1996 1996 1997 --\nheight | 1.72'

In [132]:
birthday_model.train_profile_embeddings[682][:5]

tensor([-0.3974,  0.4090,  0.3919,  1.2626, -0.1960])

In [136]:
model.eval()
model.forward_profile_text(text=[dm.train_dataset[682]['profile']])[0, :5]

tensor([-0.3974,  0.4090,  0.3919,  1.2626, -0.1960], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [140]:
list(birthday_model.named_parameters())

[('0.weight',
  Parameter containing:
  tensor([[-0.0123,  0.0198, -0.0286,  ...,  0.0052, -0.0202,  0.0360],
          [-0.0107,  0.0232,  0.0180,  ...,  0.0116, -0.0154, -0.0274],
          [-0.0222, -0.0221,  0.0122,  ...,  0.0234,  0.0198,  0.0023],
          ...,
          [ 0.0127,  0.0177, -0.0266,  ..., -0.0159, -0.0071,  0.0111],
          [-0.0245,  0.0075,  0.0298,  ..., -0.0179, -0.0173,  0.0030],
          [ 0.0115,  0.0255,  0.0330,  ..., -0.0075, -0.0049, -0.0297]],
         device='cuda:0', requires_grad=True)),
 ('0.bias',
  Parameter containing:
  tensor([-0.0013,  0.0079,  0.0005, -0.0231,  0.0133, -0.0023,  0.0213, -0.0355,
          -0.0328, -0.0144, -0.0042,  0.0066, -0.0263, -0.0157,  0.0100,  0.0275,
           0.0136, -0.0305, -0.0026, -0.0168,  0.0358, -0.0242,  0.0104,  0.0301,
           0.0180,  0.0171,  0.0291,  0.0126,  0.0347,  0.0225,  0.0016, -0.0308,
           0.0349, -0.0179, -0.0320,  0.0195, -0.0254,  0.0104,  0.0150, -0.0162,
           0.0283, -