In [1]:
import os
import torch
import re
import ipdb
import torch.nn as nn

from typing import List,Dict,Literal,Tuple
from pathlib import Path
from pydantic import BaseModel
from transformers import GPT2Tokenizer,GPT2LMHeadModel
from torch.utils.data import Dataset,DataLoader
from transformers import AdamW,get_linear_schedule_with_warmup
from torch.optim import AdamW

## Config

In [2]:
class ConfigDataSet(BaseModel):
        split: Literal['train','dev','test']
        model_name: Literal['gpt2-large','gpt2-medium','gpt2','gpt2-xl'] = 'gpt2'
        trun_limit: int = 500    
        BASEPATH : Path = Path("../data/")
        device: Literal['cuda','cpu'] = "cuda" if torch.cuda.is_available() else "cpu"

## Dataset 

In [3]:
class EnronEmailDataset(Dataset):
    # Read About MRO(Method Resolution Order)   
    def __init__(
        self,
        config: ConfigDataSet
    ):
        # As Config is at as just data we can us it with pydatic
        self.config = config
        self.file_paths: List[str] = [ self.config.BASEPATH/self.config.split/name 
                                      for name in 
                                      os.listdir(self.config.BASEPATH/self.config.split)]
        
    def clean_text(
        self,
        text:str
    ):
        text = re.sub(' +',' ',text)
        text = re.sub('\n+','\n',text)
        text = re.sub('[^A-Za-z0-9\n\s\\/.-]+','',text)
        return text
        
    def __getitem__(
        self,
        idx:int
    ):
        
        """ 
        returns the input_ids and attention_maks also tuncates if
        the email is longer that what is specified in config
        """
        with open(self.file_paths[idx],'r') as f:
            email_with_subject = f.read().strip()
        
        email,subject = email_with_subject.split("@subject\n")
        
        email = self.clean_text(email)[:self.config.trun_limit]
        subject = "@subject\n"+ subject + " <|endoftext|>"
        # Token from which CLM will start Finetuning
        st_gen_token = len(email)        
        return (email+subject,st_gen_token)
        
    def __len__(
        self
    ):
        return len(self.file_paths)

In [4]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x1505b3ad8820>

In [5]:
dataset_config = ConfigDataSet(split='train',trun_limit=500)
dataset = EnronEmailDataset(dataset_config)

In [6]:
gpt2 = GPT2LMHeadModel.from_pretrained('gpt2').to(dataset_config.device)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [7]:
optimizer = AdamW(gpt2.parameters(),
                  lr = 1e-5,
                  eps = 1e-8
                 )
criteria = nn.CrossEntropyLoss()

In [8]:
def train(optim,criteria):
    counter = 0
    for text,gen_id in dataset:
        for idx in range(gen_id,len(text)-1):
            inpt = tokenizer(text[:idx],return_tensors='pt').to(dataset_config.device)
            x = gpt2(**inpt)
            loss = criteria(x.logits[:,-1,:],target = tokenizer.encode(text[idx],\
                            return_tensors='pt').to(dataset_config.device)[0])
            loss.backward()
        counter+=1
        if counter%10==0:
            print(f"Data Points: {counter} Done")
        optim.step()

## Model 

In [9]:
train(optimizer,criteria)

Data Points: 10 Done
Data Points: 20 Done
Data Points: 30 Done
Data Points: 40 Done
Data Points: 50 Done
Data Points: 60 Done
Data Points: 70 Done
Data Points: 80 Done
Data Points: 90 Done
Data Points: 100 Done
Data Points: 110 Done
Data Points: 120 Done


KeyboardInterrupt: 

In [10]:
dataset[130]

('Lets eliminate future problems with Dow and their trader Mike Billings.\nCredit has given the O.K.\nto trade in the name of The Dow Chemical Company.\nWe will document trades on the omnibus UNLESS we are able to negotiate an ISDA with the counterparty.\nTell Mike Billings that we will not keep changing from Dow Hydrocarbons and Resources Inc. to the parent company.\nWe waste too much time and energy.\nCall if you have questions.\nSara\n@subject\nThe Dow Chemical Company <|endoftext|>',
 432)

In [23]:
out = tokenizer("Lets eliminate future problems with Dow and their trader Mike Billings.\
          \nCredit has given the O.K.\nto trade in the name of The Dow Chemical Company.\
          \nWe will document trades on the omnibus UNLESS we are able to negotiate an\
          ISDA with the counterparty.\nTell Mike Billings that we will not keep changing\
          from Dow Hydrocarbons and Resources Inc. to the parent company.\nWe waste too much\
          time and energy.\nCall if you have questions.\nSara",return_tensors='pt').to(dataset_config.device)

SyntaxError: EOL while scanning string literal (1024764421.py, line 1)

In [18]:
len(out['input_ids'][0])

150

In [19]:
x = gpt2.generate(**out,max_length=len(out['input_ids'][0])+20)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [24]:
print(tokenizer.decode(x[0]))

Lets eliminate future problems with Dow and their trader Mike Billings.          
Credit has given the O.K.
to trade in the name of The Dow Chemical Company.          
We will document trades on the omnibus UNLESS we are able to negotiate an          ISDA with the counterparty.
Tell Mike Billings that we will not keep changing          from Dow Hydrocarbons and Resources Inc. to the parent company.
We waste too much          time and energy.
Call if you have questions.
Sara <|endntennnnnnnnnnn


In [5]:
class ConfigModel(ConfigDataSet):
    gen_type: Literal['CLM',"GEN"] = "GEN"
    
class CausalLanguageModel(nn.Module):
    
    def __init__(
        self,
        config: ConfigModel
    ):
        super().__init__()
        self.config = config
        self.model = GPT2LMHeadModel.from_pretrained(self.config.model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(self.config.model_name)
        
    def forward(
        self,
        data:[str,int],
    ):
        
        gen_idx = data[1]
        if self.config.gen_type=="CLM":
            gen_idx = 1
        
        pred_logits = []
        for pt in range(gen_idx,len(data[0])):
            tok_inputs = self.tokenizer(data[0][:pt], return_tensors='pt').to(self.config.device)
            out = self.model(**tok_input)
            pred_logits.append(out.logits[:,-1,:])
            
        return torch.stack(pred_logits)

In [7]:
model_config = ConfigModel(split=dataset_config.split,trun_limit=dataset_config.trun_limit)

In [9]:
dataset_config.device

'cuda'

In [10]:
!nvidia-smi

Sat Oct 14 15:45:44 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.108.03   Driver Version: 510.108.03   CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |
| 36%   27C    P0    75W / 250W |      3MiB / 11264MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [11]:
# from transformers import DataCollatorForLanguageModeling

# dataloader = DataCollatorForLanguageModeling(dataset,

In [12]:
dataloader = DataLoader(dataset)

In [13]:
import gc
# del model
gc.collect()
torch.cuda.empty_cache()
model = CausalLanguageModel(model_config)
model.to(model_config.device)

CausalLanguageModel(
  (model): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  )
)

In [14]:
gc.collect()
torch.cuda.empty_cache()

In [15]:
!nvidia-smi

Sat Oct 14 15:45:49 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.108.03   Driver Version: 510.108.03   CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |
| 37%   28C    P2    74W / 250W |   1080MiB / 11264MiB |      4%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [16]:
from torch.optim import AdamW
epochs = 2
learning_rate = 5e-4
warmup_steps = 1e2
epsilon = 1e-8

optimizer = AdamW(model.parameters(),
                  lr = learning_rate,
                  eps = epsilon
                )

In [17]:
total_steps = len(dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, 
                                            num_warmup_steps = warmup_steps, 
                                            num_training_steps = total_steps)

In [18]:
!nvidia-smi

Sat Oct 14 15:45:50 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.108.03   Driver Version: 510.108.03   CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |
| 37%   28C    P2    74W / 250W |   1080MiB / 11264MiB |      4%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [19]:
out = next(iter(dataloader))

In [20]:
criteria = nn.CrossEntropyLoss()

In [21]:
device = dataset_config.device
for epoch_i in range(0, epochs):

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')

    
    model.train()
    
    for step, batch in enumerate(dataloader):
        ipdb.set_trace()
        input_ids = batch[0]['input_ids'].to(device)
        attention_mask = batch[0]['attention_mask'].to(device)
        gen_idx = batch[1]
        
        optimizer.zero_grad()        

        outputs = model(({'input_ids':input_ids,
                         'attention_mask': attention_mask,
                          }, gen_idx))

        ipdb.set_trace()
        # Preparing the Groud Truth
        gt = []
        output = []
        target = torch.ones(outputs.shape[0],dataset.tokenizer.vocab_size)
        target = target*(0.01/50257)
        
        for idx,idj in enumerate(input_ids[0][gen_idx:]):
            target[idx][idj]=0.99
        
        
        for idx,token_id in enumerate(input_ids[0][gen_idx:]):
            target = torch.ones(dataset.tokenizer.vocab_size)
            prob = 0.01/target.shape[0]
            target = target*prob
            target[token_id] = 0.99
            gt.append(target)
            output.append(nn.Softmax(outputs[idx]))
        ipdb.set_trace()
        loss = criteria(outputs,target)
        loss.backward()
        optimizer.step()
        scheduler.step()
        del outputs
        del target
        gc.collect()
        torch.cuda.empty_cache()

        if step % 1000 == 0 and not step == 0:

            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}. Loss: {:>5,}.   Elapsed: {:}.'.format(step, len(train_dataloader), batch_loss, elapsed))
            model.eval()
            sample_outputs = model.generate(
                                    bos_token_id=random.randint(1,30000),
                                    do_sample=True,   
                                    top_k=50, 
                                    max_length = 30,
                                    top_p=0.95, 
                                    num_return_sequences=1
                                )
            for i, sample_output in enumerate(sample_outputs):
                  print("{}: {}".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))
            model.train()


Training...
> [0;32m/tmp/ipykernel_28813/2124851288.py[0m(13)[0;36m<module>[0;34m()[0m
[0;32m     12 [0;31m        [0mipdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 13 [0;31m        [0minput_ids[0m [0;34m=[0m [0mbatch[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[[0m[0;34m'input_ids'[0m[0;34m][0m[0;34m.[0m[0mto[0m[0;34m([0m[0mdevice[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     14 [0;31m        [0mattention_mask[0m [0;34m=[0m [0mbatch[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[[0m[0;34m'attention_mask'[0m[0;34m][0m[0;34m.[0m[0mto[0m[0;34m([0m[0mdevice[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  continue


> [0;32m/tmp/ipykernel_28813/2124851288.py[0m(25)[0;36m<module>[0;34m()[0m
[0;32m     24 [0;31m        [0;31m# Preparing the Groud Truth[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 25 [0;31m        [0mgt[0m [0;34m=[0m [0;34m[[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m        [0moutput[0m [0;34m=[0m [0;34m[[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_28813/2124851288.py[0m(26)[0;36m<module>[0;34m()[0m
[0;32m     25 [0;31m        [0mgt[0m [0;34m=[0m [0;34m[[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 26 [0;31m        [0moutput[0m [0;34m=[0m [0;34m[[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     27 [0;31m        [0mtarget[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mones[0m[0;34m([0m[0moutputs[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m,[0m[0mdataset[0m[0;34m.[0m[0mtokenizer[0m[0;34m.[0m[0mvocab_size[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_28813/2124851288.py[0m(27)[0;36m<module>[0;34m()[0m
[0;32m     26 [0;31m        [0moutput[0m [0;34m=[0m [0;34m[[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 27 [0;31m        [0mtarget[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mones[0m[0;34m([0m[0moutputs[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m,[0m[0mdataset[0m[0;34m.[0m[0mtokenizer[0m[0;34m.[0m[0mvocab_size[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     28 [0;31m        [0mtarget[0m [0;34m=[0m [0mtarget[0m[0;34m*[0m[0;34m([0m[0;36m0.01[0m[0;34m/[0m[0;36m50257[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_28813/2124851288.py[0m(28)[0;36m<module>[0;34m()[0m
[0;32m     27 [0;31m        [0mtarget[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mones[0m[0;34m([0m[0moutputs[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m,[0m[0mdataset[0m[0;34m.[0m[0mtokenizer[0m[0;34m.[0m[0mvocab_size[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 28 [0;31m        [0mtarget[0m [0;34m=[0m [0mtarget[0m[0;34m*[0m[0;34m([0m[0;36m0.01[0m[0;34m/[0m[0;36m50257[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     29 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_28813/2124851288.py[0m(30)[0;36m<module>[0;34m()[0m
[0;32m     29 [0;31m[0;34m[0m[0m
[0m[0;32m---> 30 [0;31m        [0;32mfor[0m [0midx[0m[0;34m,[0m[0midj[0m [0;32min[0m [0menumerate[0m[0;34m([0m[0minput_ids[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[[0m[0mgen_idx[0m[0;34m:[0m[0;34m][0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     31 [0;31m            [0mtarget[0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m[[0m[0midj[0m[0;34m][0m[0;34m=[0m[0;36m0.99[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_28813/2124851288.py[0m(31)[0;36m<module>[0;34m()[0m
[0;32m     30 [0;31m        [0;32mfor[0m [0midx[0m[0;34m,[0m[0midj[0m [0;32min[0m [0menumerate[0m[0;34m([0m[0minput_ids[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[[0m[0mgen_idx[0m[0;34m:[0m[0;34m][0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 31 [0;31m            [0mtarget[0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m[[0m[0midj[0m[0;34m][0m[0;34m=[0m[0;36m0.99[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     32 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_28813/2124851288.py[0m(30)[0;36m<module>[0;34m()[0m
[0;32m     29 [0;31m[0;34m[0m[0m
[0m[0;32m---> 30 [0;31m        [0;32mfor[0m [0midx[0m[0;34m,[0m[0midj[0m [0;32min[0m [0menumerate[0m[0;34m([0m[0minput_ids[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[[0m[0mgen_idx[0m[0;34m:[0m[0;34m][0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     31 [0;31m            [0mtarget[0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m[[0m[0midj[0m[0;34m][0m[0;34m=[0m[0;36m0.99[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  criteria


CrossEntropyLoss()


ipdb>  loss = criteria(outputs,input_ids[0][gen_idx:])
ipdb>  loss


tensor(0.8558, device='cuda:0', grad_fn=<NllLossBackward0>)


ipdb>  loss.backward()


  File "/home2/sisodiya.bhoomendra/localpython/python3.9.16/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home2/sisodiya.bhoomendra/localpython/python3.9.16/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home2/sisodiya.bhoomendra/venvs/python3.9_global/lib/python3.9/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/home2/sisodiya.bhoomendra/venvs/python3.9_global/lib/python3.9/site-packages/traitlets/config/application.py", line 992, in launch_instance
    app.start()
  File "/home2/sisodiya.bhoomendra/venvs/python3.9_global/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 711, in start
    self.io_loop.start()
  File "/home2/sisodiya.bhoomendra/venvs/python3.9_global/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 215, in start
    self.asyncio_loop.run_forever()
  File "/home2/sisodiya.bhoomendra/localpython/p

*** RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.LongTensor [1, 194]] is at version 19; expected version 18 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!


ipdb>  exit


In [22]:
del outputs
del target
gc.collect()
torch.cuda.empty_cache()

In [25]:
!nvidia-smi

Sat Oct 14 14:52:37 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.108.03   Driver Version: 510.108.03   CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |
| 36%   26C    P8    32W / 250W |   6760MiB / 11264MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [26]:
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
            print(type(obj), obj.size())
    except:
        pass



<class 'torch.nn.parameter.Parameter'> torch.Size([50257, 768])
<class 'torch.nn.parameter.Parameter'> torch.Size([1024, 768])
<class 'torch.nn.parameter.Parameter'> torch.Size([768])
<class 'torch.nn.parameter.Parameter'> torch.Size([768])
<class 'torch.nn.parameter.Parameter'> torch.Size([768])
<class 'torch.nn.parameter.Parameter'> torch.Size([768])
<class 'torch.nn.parameter.Parameter'> torch.Size([768])
<class 'torch.nn.parameter.Parameter'> torch.Size([768])
<class 'torch.nn.parameter.Parameter'> torch.Size([768])
<class 'torch.nn.parameter.Parameter'> torch.Size([768])
<class 'torch.nn.parameter.Parameter'> torch.Size([768])
<class 'torch.nn.parameter.Parameter'> torch.Size([768])
<class 'torch.nn.parameter.Parameter'> torch.Size([768, 2304])
<class 'torch.nn.parameter.Parameter'> torch.Size([2304])
<class 'torch.nn.parameter.Parameter'> torch.Size([768, 768])
<class 'torch.nn.parameter.Parameter'> torch.Size([768])
<class 'torch.nn.parameter.Parameter'> torch.Size([768, 3072])
