In [None]:
!pip install transformers==4.14.1 -q
!pip install bitsandbytes-cuda111==0.26.0 -q
!pip install datasets==1.16.1 -q

[K     |████████████████████████████████| 3.4 MB 8.2 MB/s 
[K     |████████████████████████████████| 3.3 MB 40.2 MB/s 
[K     |████████████████████████████████| 596 kB 62.7 MB/s 
[K     |████████████████████████████████| 67 kB 5.8 MB/s 
[K     |████████████████████████████████| 895 kB 61.8 MB/s 
[K     |████████████████████████████████| 4.0 MB 9.6 MB/s 
[K     |████████████████████████████████| 298 kB 7.8 MB/s 
[K     |████████████████████████████████| 1.1 MB 60.0 MB/s 
[K     |████████████████████████████████| 243 kB 63.3 MB/s 
[K     |████████████████████████████████| 133 kB 65.7 MB/s 
[K     |████████████████████████████████| 144 kB 60.7 MB/s 
[K     |████████████████████████████████| 271 kB 66.5 MB/s 
[K     |████████████████████████████████| 160 kB 65.8 MB/s 
[?25h

### Fine-tuning the 6-billion-parameter GPT-J (and other models) in Colab with LoRA and 8-bit compression

This notebook is a simple example for fine-tuning [GPT-J-6B](https://huggingface.co/EleutherAI/gpt-j-6B) with limited memory. A detailed explanation of how it works can be found in [this model card](https://huggingface.co/hivemind/gpt-j-6B-8bit). It is heavily based on [this Colab](https://colab.research.google.com/drive/1ft6wQU0BhqG5PRlwgaZJv2VukKKjU4Es#scrollTo=vfdLQHOuEU7h). Huge thanks to Hivemind!

You can also finetune [GPT-Neo-2.7B](https://huggingface.co/gustavecortal/gpt-neo-2.7B-8bit), [French GPT-J (Cedille's Boris)](https://huggingface.co/gustavecortal/fr-boris-8bit) and [T0-3B](https://huggingface.co/gustavecortal/T0_3B-8bit) with limited memory.

Twitter: [@gustavecortal](https://twitter.com/gustavecortal)

In [None]:
from sklearn.model_selection import train_test_split

import transformers

import pandas as pd

import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import custom_fwd, custom_bwd

from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise

from tqdm.auto import tqdm

In [None]:
from google.colab import drive

# Mount Google Drive through the public API. The original called the private helper
# `drive._mount`, which is not part of the supported interface and may change without notice.
drive.mount("/content/drive")

Mounted at /content/drive


## Converting the model to 8 bits

In [None]:
class FrozenBNBLinear(nn.Module):
    """Linear layer whose weight is stored as frozen, blockwise-quantized 8-bit values.

    The quantized ``weight`` and its quantization state (``absmax``, ``code``) are
    registered as buffers, so they are not returned by ``parameters()``. Only the
    optional ``bias`` and the optional ``adapter`` module can carry trainable parameters.
    """

    def __init__(self, weight, absmax, code, bias=None):
        """Store the quantized weight, its quantization state and an optional bias.

        Args:
            weight: quantized weight of shape ``(out_features, in_features)``.
            absmax: quantization state passed to ``dequantize_blockwise`` as ``absmax``.
            code: quantization state passed to ``dequantize_blockwise`` as ``code``.
            bias: ``nn.Parameter`` or ``None``.
        """
        assert isinstance(bias, nn.Parameter) or bias is None
        super().__init__()
        self.out_features, self.in_features = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        # Optional trainable module whose output is added to the frozen projection.
        self.adapter = None
        self.bias = bias

    def forward(self, input):
        """Return the frozen linear projection of ``input`` plus the adapter output, if any."""
        output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
        # Compare against None explicitly: container modules such as nn.Sequential define
        # __len__, so a plain truthiness test would silently skip an empty container.
        if self.adapter is not None:
            output += self.adapter(input)
        return output

    @classmethod
    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
        """Quantize the weight of an existing ``nn.Linear`` and wrap it, reusing its bias."""
        weights_int8, state = quantize_blockise_lowmemory(linear.weight)
        return cls(weights_int8, *state, linear.bias)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
 
 
class DequantizeAndLinear(torch.autograd.Function):
    """Autograd function that dequantizes blockwise-quantized weights and applies ``F.linear``.

    Only the quantized weights and their quantization state are saved for the backward
    pass; the dequantized matrix is rebuilt inside ``backward`` instead of being stored.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
        # Remember whether a bias gradient has to be produced later.
        ctx._has_bias = bias is not None
        ctx.save_for_backward(input, weights_quantized, absmax, code)
        dense_weight = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        return F.linear(input, dense_weight, bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        # The quantized weight, absmax and code are frozen: none may request a gradient.
        assert not any(ctx.needs_input_grad[1:4])
        _, weights_quantized, absmax, code = ctx.saved_tensors
        dense_weight = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
        # grad_output is [*batch, out_features]; multiplying by the weight gives [*batch, in_features].
        grad_input = torch.matmul(grad_output, dense_weight)
        if ctx._has_bias:
            # Collapse every leading dimension and sum, leaving one value per output feature.
            grad_bias = grad_output.flatten(0, -2).sum(dim=0)
        else:
            grad_bias = None
        return grad_input, None, None, None, grad_bias
 
 
class FrozenBNBEmbedding(nn.Module):
    """Embedding table stored as frozen, blockwise-quantized 8-bit values.

    The quantized ``weight`` and its quantization state (``absmax``, ``code``) are
    registered as buffers; only the optional ``adapter`` module can be trained.
    """

    def __init__(self, weight, absmax, code):
        """Store the quantized table of shape ``(num_embeddings, embedding_dim)`` and its state."""
        super().__init__()
        self.num_embeddings, self.embedding_dim = weight.shape
        self.register_buffer("weight", weight.requires_grad_(False))
        self.register_buffer("absmax", absmax.requires_grad_(False))
        self.register_buffer("code", code.requires_grad_(False))
        # Optional trainable module whose output is added to the frozen lookup.
        self.adapter = None

    def forward(self, input, **kwargs):
        """Look up ``input`` in the dequantized table and add the adapter output, if any.

        Extra keyword arguments are forwarded to ``F.embedding``.
        """
        with torch.no_grad():
            # note: both quantized weights and input indices are *not* differentiable
            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
            output = F.embedding(input, weight_deq, **kwargs)
        # Compare against None explicitly: container modules such as nn.Sequential define
        # __len__, so a plain truthiness test would silently skip an empty container.
        if self.adapter is not None:
            output += self.adapter(input)
        return output

    @classmethod
    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
        """Quantize the weight of an existing ``nn.Embedding`` and wrap it."""
        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
        return cls(weights_int8, *state)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
 
 
def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
    """Blockwise-quantize ``matrix`` one chunk at a time instead of in a single call.

    Args:
        matrix: tensor to quantize; it is flattened, processed in slices of
            ``chunk_size`` elements, and the result is reshaped back to its shape.
        chunk_size: number of elements per ``quantize_blockwise`` call. Must be a
            multiple of 4096 (presumably the bitsandbytes block size - confirm).

    Returns:
        ``(matrix_i8, (absmax, code))``: the quantized tensor with the shape of
        ``matrix``, the concatenated per-chunk ``absmax`` values, and the ``code``
        returned by ``quantize_blockwise`` (the first call's result is passed back
        into every later call).
    """
    assert chunk_size % 4096 == 0, "chunk_size must be a multiple of 4096"
    code = None
    chunks = []
    absmaxes = []
    # reshape() behaves like view() on contiguous input but also accepts
    # non-contiguous tensors (e.g. a transposed weight), where view(-1) would raise.
    flat_tensor = matrix.reshape(-1)
    for i in range((matrix.numel() - 1) // chunk_size + 1):
        input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()
        quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
        chunks.append(quantized_chunk)
        absmaxes.append(absmax_chunk)

    matrix_i8 = torch.cat(chunks).reshape_as(matrix)
    absmax = torch.cat(absmaxes)
    return matrix_i8, (absmax, code)
 
 
def convert_to_int8(model):
    """Replace every nn.Linear / nn.Embedding child in ``model`` with an 8-bit placeholder.

    The replacements are FrozenBNBLinear / FrozenBNBEmbedding modules built from
    zero-filled tensors of the matching shapes; no values are copied from the
    original weights. A replaced Linear keeps its original bias object. The model
    is modified in place and nothing is returned.
    """

    def absmax_length(layer):
        # One entry per started group of 4096 weights
        # (presumably the bitsandbytes block size - confirm).
        return (layer.weight.numel() - 1) // 4096 + 1

    # Snapshot the module list first: children are swapped while walking the tree.
    for parent in list(model.modules()):
        for child_name, child in parent.named_children():
            if isinstance(child, nn.Linear):
                print(child_name, child)
                placeholder = FrozenBNBLinear(
                    weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
                    absmax=torch.zeros(absmax_length(child)),
                    code=torch.zeros(256),
                    bias=child.bias,
                )
            elif isinstance(child, nn.Embedding):
                placeholder = FrozenBNBEmbedding(
                    weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
                    absmax=torch.zeros(absmax_length(child)),
                    code=torch.zeros(256),
                )
            else:
                continue
            setattr(parent, child_name, placeholder)

You have to monkey-patch GPT-J before loading the 8-bit checkpoint:

In [None]:
class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
    """GPT-J block whose attention and MLP sub-modules are passed through convert_to_int8."""

    def __init__(self, config):
        super().__init__(config)

        # Swap the Linear layers built by the parent constructor for 8-bit placeholders.
        convert_to_int8(self.attn)
        convert_to_int8(self.mlp)


class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
    """GPT-J backbone that runs convert_to_int8 over all of its sub-modules after construction."""

    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)
        

class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
    """GPT-J causal LM that runs convert_to_int8 over all of its sub-modules after construction."""

    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)


# Rebind the library module's GPTJBlock name to the subclass above, presumably so that
# models built afterwards by transformers instantiate the 8-bit block - confirm.
transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock

If you're using another 8-bit quantized model (e.g. T0-3B), remember to monkey-patch its model class with `convert_to_int8()` in the same way, as shown below.

In [None]:
class T5ForConditionalGeneration(transformers.models.t5.modeling_t5.T5ForConditionalGeneration):
    """T5 model that runs convert_to_int8 over all of its sub-modules after construction."""

    def __init__(self, config):
        super().__init__(config)
        convert_to_int8(self)

# Rebind the library module's class name to the 8-bit subclass defined above.
transformers.models.t5.modeling_t5.T5ForConditionalGeneration = T5ForConditionalGeneration

In [None]:
# Fetch the configuration and tokenizer published for the original GPT-J-6B model.
config = transformers.GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

Downloading:   0%|          | 0.00/826 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/619 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/779k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.31M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.94k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/357 [00:00<?, ?B/s]

In [None]:
# Reuse the end-of-sequence token for padding.
config.pad_token_id = config.eos_token_id
# `pad_token` must be the token *string*, not its integer id: the tokenizer derives
# `pad_token_id` from the string. Assigning the id only appeared to work through the
# unknown-token fallback, and newer transformers releases reject non-string values.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Load the checkpoint through the patched GPTJForCausalLM class defined in this notebook,
# so its Linear/Embedding layers are FrozenBNB* modules when the weights are loaded.
gpt = GPTJForCausalLM.from_pretrained("hivemind/gpt-j-6B-8bit", low_cpu_mem_usage=True)
#gpt = GPTJForCausalLM.from_pretrained("gustavecortal/fr-boris-8bit", low_cpu_mem_usage=True) French GPT-J Cedille's Boris

k_proj Linear(in_features=4096, out_features=4096, bias=False)
v_proj Linear(in_features=4096, out_features=4096, bias=False)
q_proj Linear(in_features=4096, out_features=4096, bias=False)
out_proj Linear(in_features=4096, out_features=4096, bias=False)
fc_in Linear(in_features=4096, out_features=16384, bias=True)
fc_out Linear(in_features=16384, out_features=4096, bias=True)
k_proj Linear(in_features=4096, out_features=4096, bias=False)
v_proj Linear(in_features=4096, out_features=4096, bias=False)
q_proj Linear(in_features=4096, out_features=4096, bias=False)
out_proj Linear(in_features=4096, out_features=4096, bias=False)
fc_in Linear(in_features=4096, out_features=16384, bias=True)
fc_out Linear(in_features=16384, out_features=4096, bias=True)
k_proj Linear(in_features=4096, out_features=4096, bias=False)
v_proj Linear(in_features=4096, out_features=4096, bias=False)
q_proj Linear(in_features=4096, out_features=4096, bias=False)
out_proj Linear(in_features=4096, out_features=4096, 

## LoRA fine-tuning example

You can load my very small dataset composed of philosophical sentences: 

In [None]:
!gdown --id 1Q1WMjny26VHLKb71iTCHIS5zvdm9c-wz

Downloading...
From: https://drive.google.com/uc?id=1Q1WMjny26VHLKb71iTCHIS5zvdm9c-wz
To: /content/pgbp_example.csv
  0% 0.00/41.0k [00:00<?, ?B/s]100% 41.0k/41.0k [00:00<00:00, 30.6MB/s]


In [None]:
# Read the example sentences and prepend the prompt prefix the model is fine-tuned on.
data = pd.read_csv('/content/pgbp_example.csv')
data['sentence'] = 'Quote: ' + data['sentence']
# Hold out 1% of the rows. A fixed random_state makes the split (and therefore the
# written CSV files) identical on every run of the notebook.
train, test = train_test_split(data, test_size=0.01, random_state=42)
train.to_csv('/content/train_pgbp_example.csv', index=False)
test.to_csv('/content/test_pgbp_example.csv', index=False)

In [None]:
from datasets import load_dataset
# Load the two CSV files written by the previous cell as the 'train' and 'test' splits.
dataset = load_dataset('csv', data_files={'train': '/content/train_pgbp_example.csv',
                                              'test': '/content/test_pgbp_example.csv'})

Using custom data configuration default-8e51d309cdf718da


Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-8e51d309cdf718da/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a...


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-8e51d309cdf718da/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
def tokenize_function(examples):
    """Tokenize the ``sentence`` column of a batch with padding and truncation (max_length 128)."""
    sentences = examples["sentence"]
    return tokenizer(sentences, max_length=128, truncation=True, padding=True)

# Tokenize in batches, drop the raw text column, and have the datasets return torch tensors.
tokenized_datasets = dataset.map(tokenize_function, batched=True).remove_columns(["sentence"])
tokenized_datasets.set_format("torch")

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [None]:
from torch.utils.data import DataLoader

full_train_dataset = tokenized_datasets["train"]
# NOTE(review): shuffle=False yields the batches in the same order every epoch -
# confirm this is intended for training.
train_dataloader = DataLoader(full_train_dataset, shuffle=False, batch_size=8)

Add adapters to Embedding/MLP/Attention/LMHead layers

In [None]:
def add_adapters(model, adapter_dim=4, p=0.1):
    """Attach trainable low-rank adapters to the frozen 8-bit layers of ``model`` in place.

    Every FrozenBNBEmbedding gets an adapter. A FrozenBNBLinear gets one only when its
    module name contains "attn", "mlp" or "head". Each adapter is a bottleneck of width
    ``adapter_dim``, followed by dropout with probability ``p`` and a projection back to
    the layer's output size; that last projection is zero-initialised, so a freshly
    added adapter contributes nothing to the layer output.
    """
    assert adapter_dim > 0

    def bottleneck(entry_layer, output_dim):
        # entry layer -> dropout -> projection back to the layer's output size
        return nn.Sequential(
            entry_layer,
            nn.Dropout(p=p),
            nn.Linear(adapter_dim, output_dim, bias=False),
        )

    for name, module in model.named_modules():
        if isinstance(module, FrozenBNBLinear):
            if not any(tag in name for tag in ("attn", "mlp", "head")):
                print("Not adding adapter to", name)
                continue
            print("Adding adapter to", name)
            module.adapter = bottleneck(
                nn.Linear(module.in_features, adapter_dim, bias=False),
                module.out_features,
            )
            print("Initializing", name)
            nn.init.zeros_(module.adapter[2].weight)
        elif isinstance(module, FrozenBNBEmbedding):
            print("Adding adapter to", name)
            module.adapter = bottleneck(
                nn.Embedding(module.num_embeddings, adapter_dim),
                module.embedding_dim,
            )
            print("Initializing", name)
            nn.init.zeros_(module.adapter[2].weight)

# Attach adapters with the default settings (adapter_dim=4, p=0.1), then move the
# whole model to the GPU when one is available.
add_adapters(gpt)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gpt.to(device)

Adding adapter to transformer.wte
Initializing transformer.wte
Adding adapter to transformer.h.0.attn.k_proj
Initializing transformer.h.0.attn.k_proj
Adding adapter to transformer.h.0.attn.v_proj
Initializing transformer.h.0.attn.v_proj
Adding adapter to transformer.h.0.attn.q_proj
Initializing transformer.h.0.attn.q_proj
Adding adapter to transformer.h.0.attn.out_proj
Initializing transformer.h.0.attn.out_proj
Adding adapter to transformer.h.0.mlp.fc_in
Initializing transformer.h.0.mlp.fc_in
Adding adapter to transformer.h.0.mlp.fc_out
Initializing transformer.h.0.mlp.fc_out
Adding adapter to transformer.h.1.attn.k_proj
Initializing transformer.h.1.attn.k_proj
Adding adapter to transformer.h.1.attn.v_proj
Initializing transformer.h.1.attn.v_proj
Adding adapter to transformer.h.1.attn.q_proj
Initializing transformer.h.1.attn.q_proj
Adding adapter to transformer.h.1.attn.out_proj
Initializing transformer.h.1.attn.out_proj
Adding adapter to transformer.h.1.mlp.fc_in
Initializing transfor

GPTJForCausalLM(
  (transformer): GPTJModel(
    (wte): FrozenBNBEmbedding(50400, 4096)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0): GPTJBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTJAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): FrozenBNBLinear(4096, 4096)
          (v_proj): FrozenBNBLinear(4096, 4096)
          (q_proj): FrozenBNBLinear(4096, 4096)
          (out_proj): FrozenBNBLinear(4096, 4096)
        )
        (mlp): GPTJMLP(
          (fc_in): FrozenBNBLinear(4096, 16384)
          (fc_out): FrozenBNBLinear(16384, 4096)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
      (1): GPTJBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTJAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0

In [None]:
from bitsandbytes.optim import Adam8bit

# NOTE(review): gradient checkpointing is enabled again at the top of the training cell.
gpt.gradient_checkpointing_enable()
# The quantized weights are registered as buffers by the FrozenBNB* classes, so they are
# not part of gpt.parameters() and are never handed to the optimizer.
optimizer = Adam8bit(gpt.parameters(), lr=1e-5, weight_decay=0.01)

In [None]:
num_epochs = 5
# The training loop steps the optimizer and scheduler once per batch, so the total
# number of steps is epochs x batches per epoch.
num_training_steps = num_epochs * len(train_dataloader)

In [None]:
# Linear schedule with warm-up; the warm-up length passed here is 10% of all training steps.
lr_scheduler = transformers.get_linear_schedule_with_warmup(
    optimizer, int(num_training_steps*0.1), num_training_steps
)

In [None]:
# Destination passed to torch.save by the training loop.
# NOTE(review): this path is under /content, not under the mounted Drive - presumably
# it is lost when the Colab runtime is recycled; confirm.
filepath = '/content/model.pt'

In [None]:
from tqdm.auto import tqdm

scaler = torch.cuda.amp.GradScaler()
progress_bar = tqdm(range(num_training_steps))
gpt.train()
gpt.gradient_checkpointing_enable()
k = 0  # global step counter across all epochs; also written into checkpoints

for epoch in range(num_epochs):
    for batch in train_dataloader:

        k = k + 1
        if k % 500 == 0:
            print(k)
            # Store the epoch currently running; the original stored the constant
            # `num_epochs`, which made every checkpoint look like the final epoch.
            state = {'k' : k, 'epoch': epoch, 'lr_scheduler': lr_scheduler.state_dict(), 'state_dict': gpt.state_dict(), 'optimizer': optimizer.state_dict()}
            torch.save(state, filepath)

        # Distinct names in the comprehension so the step counter `k` is not visually shadowed.
        batch = {name: tensor.to(device) for name, tensor in batch.items()}

        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            # Call the module (not .forward) so registered hooks run, and disable the
            # key/value cache explicitly: it is incompatible with gradient checkpointing
            # and otherwise triggers a warning on every step.
            out = gpt(**batch, use_cache=False)

            # Next-token targets. Padded positions are set to -100, the default
            # ignore_index of F.cross_entropy, so padding does not contribute to the loss.
            labels = batch['input_ids'][:, 1:]
            if 'attention_mask' in batch:
                labels = labels.masked_fill(batch['attention_mask'][:, 1:] == 0, -100)

            loss = F.cross_entropy(out.logits[:, :-1, :].flatten(0, -2), labels.flatten(),
                                   reduction='mean', label_smoothing=0.1)

        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to the true gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(gpt.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        lr_scheduler.step()
        # Show the loss on the progress bar instead of printing one line per step.
        progress_bar.set_postfix(loss=loss.item())
        progress_bar.update(1)

  0%|          | 0/195 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(11.3158, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.0310, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.9427, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.9920, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(11.1014, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.9951, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.7877, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(11.3448, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.4773, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.8882, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(11.4900, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(11.0338, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(11.5174, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.6410, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(11.1510, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(11.3515, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.4751, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.9604, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.9595, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.3635, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.4949, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(11.0150, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(11.0111, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.8622, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.7936, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.4939, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.9784, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.4860, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.5992, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.3410, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.7744, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.3374, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.3144, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.7441, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.0762, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.1815, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.2738, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.1250, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.6105, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.6222, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.4779, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.2384, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.2204, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.3101, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.3576, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.0275, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.4889, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.6929, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.3570, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.5123, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.1047, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.4925, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.7831, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.1707, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(10.2951, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.6223, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.9537, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.9524, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.4809, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.5489, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.9023, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.9835, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.8256, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.7442, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.5117, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.8770, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.4707, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.5515, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.3676, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.8626, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.4984, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.2670, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.5621, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.1319, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.1902, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.4538, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.1322, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.5908, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.4786, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.6148, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.1710, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.1320, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.2043, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.4960, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.0266, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.3248, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.6965, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.6818, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.2881, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.9789, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.2571, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.7350, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.9947, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(9.0476, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.6175, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.8283, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.7677, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.4410, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.4898, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.6723, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.7815, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.6587, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.5320, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.4079, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.5498, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.3106, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.3125, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.1796, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.8200, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.5151, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.0117, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.2077, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.9697, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.0071, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.3582, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.9596, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.2137, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(8.0392, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.5362, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.8480, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.7825, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.7925, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.4155, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.6968, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.8626, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.3979, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.8212, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.6896, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.5636, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.6426, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.3509, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.4555, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.4522, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.3055, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.3439, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.3076, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.1115, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.1322, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.1082, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.2789, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.2003, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.0913, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.0467, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(7.0537, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.8592, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.9009, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.8462, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.6570, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4811, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.6868, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.7951, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.7576, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.7866, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.3516, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.7252, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.9031, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.7398, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4902, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.6751, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.5614, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.5979, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4734, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.6134, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.7155, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4326, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.1221, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.5527, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4954, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.5715, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4055, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4457, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4432, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4396, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4399, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4193, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.2522, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.3099, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.2671, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4735, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4041, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.3435, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.3509, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.3037, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.1492, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.2199, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.2416, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.1641, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.0116, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.1526, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.1898, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.2912, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.2786, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(5.9527, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.2827, device='cuda:0', grad_fn=<AddBackward0>)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


tensor(6.4700, device='cuda:0', grad_fn=<AddBackward0>)


## Text generation example

In [None]:
# Put the model in evaluation mode before generating text.
gpt.eval()

# No gradients are needed for generation, so run everything under no_grad.
with torch.no_grad():
  # Tokenize the fixed prompt "Quote:" and move every returned tensor
  # onto the device used by the rest of the notebook.
  prompt = {
      name: tensor.to(device)
      for name, tensor in tokenizer(
          "Quote:",
          truncation=True,
          padding=True,
          max_length=128,
          return_tensors='pt',
      ).items()
  }

  # Sample (do_sample=True, single beam) with top-k / top-p filtering and a
  # repetition penalty; max_length=128 is passed through to generate().
  out = gpt.generate(
      **prompt,
      max_length=128,
      top_k=50,
      top_p=0.9,
      temperature=1.0,
      do_sample=True,
      repetition_penalty=1.2,
      num_beams=1,
  )

  # Decode and show only the first generated sequence.
  print(tokenizer.decode(out[0]))