<a href="https://colab.research.google.com/github/colleenmacklin/GPTJ_Finetune_8bit/blob/main/finetune_8bit_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers==4.14.1 -q
#!pip install bitsandbytes-cuda111==0.26.0 -q
!pip install bitsandbytes-cuda113
!pip install bitsandbytes
!pip install datasets==1.16.1 -q

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


### Fine-tuning 6-Billion GPT-J (& other models) in colab with LoRA and 8-bit compression

This notebook is a simple example for fine-tuning [GPT-J-6B](https://huggingface.co/EleutherAI/gpt-j-6B) with limited memory. A detailed explanation of how it works can be found in [this model card](https://huggingface.co/hivemind/gpt-j-6B-8bit). It is heavily based on [this Colab](https://colab.research.google.com/drive/1ft6wQU0BhqG5PRlwgaZJv2VukKKjU4Es#scrollTo=vfdLQHOuEU7h). Huge thanks to Hivemind!

You can also finetune [GPT-Neo-2.7B](https://huggingface.co/gustavecortal/gpt-neo-2.7B-8bit), [French GPT-J (Cedille's Boris)](https://huggingface.co/gustavecortal/fr-boris-8bit) and [T0-3B](https://huggingface.co/gustavecortal/T0_3B-8bit) with limited memory.

Twitter: [@gustavecortal](https://twitter.com/gustavecortal)

In [2]:
from sklearn.model_selection import train_test_split

import transformers

import pandas as pd

import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import custom_fwd, custom_bwd

from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise

from tqdm.auto import tqdm

In [5]:
from google.colab import drive
#drive._mount("/content/drive")
drive.mount("/content/drive")

Mounted at /content/drive


## Converting the model to 8 bits

In [6]:
class FrozenBNBLinear(nn.Module):
    def __init__(self, weight, absmax, code, bias=None):
        assert isinstance(bias, nn.Parameter) or bias is None
        super().__init__()
        self.out_features, self.in_features = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None
        self.bias = bias
 
    def forward(self, input):
        output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
        if self.adapter:
            output += self.adapter(input)
        return output
 
    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
        weights_int8, state = quantize_blockise_lowmemory(linear.weight)
        return cls(weights_int8, *state, linear.bias)
 
    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
 
 
class DequantizeAndLinear(torch.autograd.Function): 
    @staticmethod
    @custom_fwd
    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        ctx.save_for_backward(input, weights_quantized, absmax, code)
        ctx._has_bias = bias is not None
        return F.linear(input, weights_deq, bias).clone()
 
    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
        input, weights_quantized, absmax, code = ctx.saved_tensors
        # grad_output: [*batch, out_features]
        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        grad_input = grad_output @ weights_deq
        grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
        return grad_input, None, None, None, grad_bias
 
 
class FrozenBNBEmbedding(nn.Module):
    def __init__(self, weight, absmax, code):
        super().__init__()
        self.num_embeddings, self.embedding_dim = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        self.adapter = None
 
    def forward(self, input, **kwargs):
        with torch.no_grad():
            # note: both quantuized weights and input indices are *not* differentiable
            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
            output = F.embedding(input, weight_deq, **kwargs)
        if self.adapter:
            output += self.adapter(input)
        return output 
 
    @classmethod
    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
        return cls(weights_int8, *state)
 
    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
 
 
def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
    assert chunk_size % 4096 == 0
    code = None
    chunks = []
    absmaxes = []
    flat_tensor = matrix.view(-1)
    for i in range((matrix.numel() - 1) // chunk_size + 1):
        input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()
        quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
        chunks.append(quantized_chunk)
        absmaxes.append(absmax_chunk)
 
    matrix_i8 = torch.cat(chunks).reshape_as(matrix)
    absmax = torch.cat(absmaxes)
    return matrix_i8, (absmax, code)
 
 
def convert_to_int8(model):
    """Convert linear and embedding modules to 8-bit with optional adapters"""
    for module in list(model.modules()):
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                print(name, child)
                setattr(
                    module,
                    name,
                    FrozenBNBLinear(
                        weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                        code=torch.zeros(256),
                        bias=child.bias,
                    ),
                )
            elif isinstance(child, nn.Embedding):
                setattr(
                    module,
                    name,
                    FrozenBNBEmbedding(
                        weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
                        code=torch.zeros(256),
                    )
                )

You have to Monkey-Patch GPT-J before loading: 

In [7]:
class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
    def __init__(self, config):
        super().__init__(config)

        convert_to_int8(self.attn)
        convert_to_int8(self.mlp)


class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)
        

class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)


transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock

If you're using another 8-bit quantized model (e.g. T0-3B), remember to Monkey-Patch the model using convert_to_int8()

In [8]:
class T5ForConditionalGeneration(transformers.models.t5.modeling_t5.T5ForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)

transformers.models.t5.modeling_t5.T5ForConditionalGeneration = T5ForConditionalGeneration

In [9]:
config = transformers.GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

Downloading:   0%|          | 0.00/930 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/619 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/779k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.31M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.94k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/357 [00:00<?, ?B/s]

In [10]:
config.pad_token_id = config.eos_token_id
tokenizer.pad_token = config.pad_token_id

In [11]:
gpt = GPTJForCausalLM.from_pretrained("hivemind/gpt-j-6B-8bit", low_cpu_mem_usage=True)
#gpt = GPTJForCausalLM.from_pretrained("gustavecortal/fr-boris-8bit", low_cpu_mem_usage=True) French GPT-J Cedille's Boris

Downloading:   0%|          | 0.00/1.00k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/5.75G [00:00<?, ?B/s]

k_proj Linear(in_features=4096, out_features=4096, bias=False)
v_proj Linear(in_features=4096, out_features=4096, bias=False)
q_proj Linear(in_features=4096, out_features=4096, bias=False)
out_proj Linear(in_features=4096, out_features=4096, bias=False)
fc_in Linear(in_features=4096, out_features=16384, bias=True)
fc_out Linear(in_features=16384, out_features=4096, bias=True)
k_proj Linear(in_features=4096, out_features=4096, bias=False)
v_proj Linear(in_features=4096, out_features=4096, bias=False)
q_proj Linear(in_features=4096, out_features=4096, bias=False)
out_proj Linear(in_features=4096, out_features=4096, bias=False)
fc_in Linear(in_features=4096, out_features=16384, bias=True)
fc_out Linear(in_features=16384, out_features=4096, bias=True)
k_proj Linear(in_features=4096, out_features=4096, bias=False)
v_proj Linear(in_features=4096, out_features=4096, bias=False)
q_proj Linear(in_features=4096, out_features=4096, bias=False)
out_proj Linear(in_features=4096, out_features=4096, 

## LoRA fine-tuning example

You can load my very small dataset composed of philosophical sentences: 

In [12]:
!gdown --id 1Q1WMjny26VHLKb71iTCHIS5zvdm9c-wz

Downloading...
From: https://drive.google.com/uc?id=1Q1WMjny26VHLKb71iTCHIS5zvdm9c-wz
To: /content/pgbp_example.csv
100% 41.0k/41.0k [00:00<00:00, 60.7MB/s]


In [13]:
data = pd.read_csv('/content/pgbp_example.csv')
data['sentence'] = 'Quote: ' + data['sentence']
train, test = train_test_split(data, test_size=0.01) 
train.to_csv('/content/train_pgbp_example.csv', index=False)
test.to_csv('/content/test_pgbp_example.csv', index=False)

In [14]:
from datasets import load_dataset
dataset = load_dataset('csv', data_files={'train': '/content/train_pgbp_example.csv',
                                              'test': '/content/test_pgbp_example.csv'})



Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-5f2f404ed9008573/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a...


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-5f2f404ed9008573/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [15]:
def tokenize_function(examples):
    return tokenizer(examples["sentence"], padding=True, truncation=True, max_length= 128)

tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(["sentence"])
tokenized_datasets.set_format("torch")



  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [16]:
from torch.utils.data import DataLoader

full_train_dataset = tokenized_datasets["train"]
train_dataloader = DataLoader(full_train_dataset, shuffle=False, batch_size=8)

Add adapters to Embedding/MLP/Attention/LMHead layers

In [17]:
def add_adapters(model, adapter_dim=4, p = 0.1):
    assert adapter_dim > 0

    for name, module in model.named_modules():
      if isinstance(module, FrozenBNBLinear):
          if "attn" in name or "mlp" in name or "head" in name:
              print("Adding adapter to", name)
              module.adapter = nn.Sequential(
                nn.Linear(module.in_features, adapter_dim, bias=False),
                nn.Dropout(p=p),
                nn.Linear(adapter_dim, module.out_features, bias=False),
            )
              print("Initializing", name)
              nn.init.zeros_(module.adapter[2].weight)

          else:
              print("Not adding adapter to", name)
      elif isinstance(module, FrozenBNBEmbedding):
          print("Adding adapter to", name)
          module.adapter = nn.Sequential(
                nn.Embedding(module.num_embeddings, adapter_dim),
                nn.Dropout(p=p),
                nn.Linear(adapter_dim, module.embedding_dim, bias=False),
            )
          print("Initializing", name)
          nn.init.zeros_(module.adapter[2].weight)

add_adapters(gpt)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gpt.to(device)

Adding adapter to transformer.wte
Initializing transformer.wte
Adding adapter to transformer.h.0.attn.k_proj
Initializing transformer.h.0.attn.k_proj
Adding adapter to transformer.h.0.attn.v_proj
Initializing transformer.h.0.attn.v_proj
Adding adapter to transformer.h.0.attn.q_proj
Initializing transformer.h.0.attn.q_proj
Adding adapter to transformer.h.0.attn.out_proj
Initializing transformer.h.0.attn.out_proj
Adding adapter to transformer.h.0.mlp.fc_in
Initializing transformer.h.0.mlp.fc_in
Adding adapter to transformer.h.0.mlp.fc_out
Initializing transformer.h.0.mlp.fc_out
Adding adapter to transformer.h.1.attn.k_proj
Initializing transformer.h.1.attn.k_proj
Adding adapter to transformer.h.1.attn.v_proj
Initializing transformer.h.1.attn.v_proj
Adding adapter to transformer.h.1.attn.q_proj
Initializing transformer.h.1.attn.q_proj
Adding adapter to transformer.h.1.attn.out_proj
Initializing transformer.h.1.attn.out_proj
Adding adapter to transformer.h.1.mlp.fc_in
Initializing transfor

GPTJForCausalLM(
  (transformer): GPTJModel(
    (wte): FrozenBNBEmbedding(50400, 4096)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0): GPTJBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTJAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): FrozenBNBLinear(4096, 4096)
          (v_proj): FrozenBNBLinear(4096, 4096)
          (q_proj): FrozenBNBLinear(4096, 4096)
          (out_proj): FrozenBNBLinear(4096, 4096)
        )
        (mlp): GPTJMLP(
          (fc_in): FrozenBNBLinear(4096, 16384)
          (fc_out): FrozenBNBLinear(16384, 4096)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
      (1): GPTJBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTJAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0

In [18]:
from bitsandbytes.optim import Adam8bit

gpt.gradient_checkpointing_enable()
optimizer = Adam8bit(gpt.parameters(), lr=1e-5, weight_decay=0.01)

In [19]:
num_epochs = 5
num_training_steps = num_epochs * len(train_dataloader)

In [20]:
lr_scheduler = transformers.get_linear_schedule_with_warmup(
    optimizer, int(num_training_steps*0.1), num_training_steps
)

In [21]:
filepath = '/content/model.pt'

In [22]:
from tqdm.auto import tqdm

scaler = torch.cuda.amp.GradScaler()
progress_bar = tqdm(range(num_training_steps))
gpt.train()
gpt.gradient_checkpointing_enable()
k = 0

for epoch in range(num_epochs):
    for batch in train_dataloader:

        k = k + 1
        if k % 500 == 0:
          print(k)
          state = {'k' : k, 'epoch': num_epochs, 'lr_scheduler': lr_scheduler.state_dict(), 'state_dict': gpt.state_dict(), 'optimizer': optimizer.state_dict()}
          torch.save(state, filepath)

        batch = {k: v.to(device) for k, v in batch.items()}

        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
          out = gpt.forward(**batch,)

          loss = F.cross_entropy(out.logits[:, :-1, :].flatten(0, -2), batch['input_ids'][:, 1:].flatten(),
                                reduction='mean', label_smoothing=0.1)
          
        print(loss)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(gpt.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        lr_scheduler.step()
        progress_bar.update(1)

  0%|          | 0/195 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.9087, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.7763, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(11.1863, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.9792, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.5578, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.9826, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.8027, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.1897, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(11.2877, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.8544, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(11.1894, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(11.1877, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.6969, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.1743, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.3146, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.6482, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.8271, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(11.4015, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.3020, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.9816, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(11.4741, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.0429, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.9546, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(11.0520, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.5834, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.1992, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.1338, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.5388, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.1843, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.8650, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.6312, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.9371, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.2537, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.4894, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.1032, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.4602, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.5699, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.4010, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(11.0203, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.2531, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.9396, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.2523, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.0170, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.6746, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.0306, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.0683, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.3737, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.2164, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.8867, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.0924, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.1172, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.6670, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.2817, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.3526, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.6368, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.7329, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.1400, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.3245, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.8636, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.1870, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.2883, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.0671, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.9132, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.4914, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.2062, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.1572, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.4538, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.1799, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.7550, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.5110, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.9714, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.4163, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.5978, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.0613, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.3316, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.4416, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.2904, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.6979, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.3946, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.8955, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.1479, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.9504, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.6965, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.9420, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.2453, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.4719, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.0892, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.8421, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.9677, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.9758, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.6180, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.3614, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.3816, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.5851, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.6185, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.8427, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.3288, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.6831, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.8554, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.4613, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.0990, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.7102, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.3211, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.0868, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.0525, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.2529, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.0536, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.3914, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.1963, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.8272, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.4417, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.5482, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.8496, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.0037, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.0557, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.9225, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.1149, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.3404, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.5949, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.7422, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.5898, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.4171, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.5899, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.1576, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.2704, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.6159, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.4703, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.5270, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.4963, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.2193, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.1155, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.1817, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.2043, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.2455, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.2714, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.0236, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.1987, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.2317, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4794, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.8959, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.1889, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.9071, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.7915, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.7741, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.8793, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.7530, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.9158, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.7806, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.6364, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4153, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4437, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.5813, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.6510, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.7616, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.6572, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.7372, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.3468, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4656, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.5014, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4676, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.3380, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4411, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.2559, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.3039, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.5044, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4480, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4638, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4401, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.2559, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.2064, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.2697, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.2915, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.3096, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.2987, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.2241, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.3123, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.3533, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(5.8812, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.2150, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4221, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.1615, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.1369, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.1337, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.1962, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.1193, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.2476, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.1448, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.0493, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(5.9530, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(5.9793, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.0822, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.1143, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.2564, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.1749, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.1738, device='cuda:0', grad_fn=<AddBackward0>)


## Text generation example

In [25]:
gpt.eval()
with torch.no_grad():
  prompt = tokenizer("Quote:", truncation=True, padding=True, max_length=128, return_tensors='pt')
  prompt = {key: value.to(device) for key, value in prompt.items()}
  out = gpt.generate(**prompt, max_length=128, top_k=50, top_p=0.9, temperature=1.0, do_sample=True, repetition_penalty = 1.2, num_beams=1)
  print(tokenizer.decode(out[0]))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Quote: If we are to be true to the highest ideals for which our
nation has ever fought, let us resolve that today those rights and privileges shall be unsullied by invidious distinctions of race or of sex." - John F. Kennedy (Presidential Address before The American Society of Anesthesiologists)<|endoftext|>


In [32]:
gpt.save_pretrained("/content/pretrained_gptj")
#gpt.save_state_dict("/content/pretrained_gptj/state_dict")
torch.save(gpt.state_dict(), "/content/pretrained_gptj/state_dict")
