Computational graph not retained for BERT #70
Comments
Can you summarize the proto-MAML approach briefly? Also, why is [...]?
In ProtoMAML, the network representations are first computed for all the data points in the support set. Then a "prototype" is generated for each class as the mean of the representations of the data points belonging to that class. This prototype is used to initialize the final classification (softmax) layer: the weight is 2 * prototype and the bias is the negative squared norm of the prototype. The outer-loop backprop requires gradients from two parts: the inner-loop updated model as well as the original classification-layer initialization. I used [...]
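A minimal sketch of the prototype-based initialization described above (illustrative names only, not the exact code from this issue; it assumes the support representations are a 2-D tensor of shape (n_examples, hidden_dim)):

import torch

def proto_init(support_repr, support_labels, n_classes):
    # Build one prototype per class as the mean representation of that class's support points.
    hidden_dim = support_repr.size(1)
    prototypes = torch.zeros(n_classes, hidden_dim)
    for c in range(n_classes):
        mask = support_labels == c
        if mask.any():
            prototypes[c] = support_repr[mask].mean(dim=0)
    # ProtoMAML initialization of the softmax layer: W = 2 * prototype, b = -||prototype||^2
    weight = 2 * prototypes
    bias = -prototypes.norm(dim=1) ** 2
    # Logits for a query representation q would then be F.linear(q, weight, bias).
    return weight, bias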
It's not clear to me if you need [...]. Anyway, the issue you're having comes from [...]
@egrefen The issue also occurs when [...]. Here is a self-contained script:

import higher
import torch
from torch import nn, optim
from transformers import BertModel, BertTokenizer
from torch.nn import functional as F


class BaseModel(nn.Module):

    def __init__(self, encoder, max_length, device):
        super(BaseModel, self).__init__()
        self.max_length = max_length
        self.device = device

        if encoder == 'bert':
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.encoder = BertModel.from_pretrained('bert-base-uncased')
            self.encoder.pooler.dense.weight.requires_grad = False
            self.encoder.pooler.dense.bias.requires_grad = False
            self.n_tunable_layers = 1
            tunable_layers = {str(l) for l in range(12 - self.n_tunable_layers, 12)}
            for name, param in self.encoder.named_parameters():
                if not set.intersection(set(name.split('.')), tunable_layers):
                    param.requires_grad = False
        elif encoder == 'lstm':
            self.encoder = nn.LSTM(batch_first=True, input_size=32, hidden_size=768)

        self.linear = nn.Linear(768, 192)
        self.to(self.device)

    def encode_text(self, text):
        if isinstance(self.encoder, BertModel):
            encode_result = self.tokenizer.batch_encode_plus(text, return_token_type_ids=False,
                                                             max_length=self.max_length,
                                                             pad_to_max_length=True, return_tensors='pt')
            for key in encode_result:
                encode_result[key] = encode_result[key].to(self.device)
            return encode_result
        elif isinstance(self.encoder, nn.LSTM):
            return torch.randn((len(text), 32, 32), device=self.device)

    def forward(self, inputs):
        if isinstance(self.encoder, BertModel):
            out, _ = self.encoder(inputs['input_ids'], attention_mask=inputs['attention_mask'])
        elif isinstance(self.encoder, nn.LSTM):
            out, _ = self.encoder(inputs)
        out = out[:, 1:-1, :]
        out = self.linear(out)
        return out


class ProtoMAML:

    def __init__(self, device, encoder):
        self.output_layer_weight = None
        self.output_layer_bias = None
        self.learner = BaseModel(encoder=encoder, max_length=32, device=device)
        self.inner_optimizer = optim.SGD([p for p in self.learner.parameters() if p.requires_grad], lr=0.001)
        self.loss_fn = nn.CrossEntropyLoss()
        self.output_lr = 0.001
        self.device = device
        self.updates = 5

    def output_layer(self, input, weight, bias):
        return F.linear(input, self.output_layer_weight + weight, self.output_layer_bias + bias)

    def initialize_with_proto_weights(self, support_repr, support_label, n_classes):
        prototypes = self.build_prototypes(support_repr, support_label, n_classes)
        weight = 2 * prototypes
        bias = -torch.norm(prototypes, dim=1) ** 2
        self.output_layer_weight = torch.zeros_like(weight, requires_grad=True)
        self.output_layer_bias = torch.zeros_like(bias, requires_grad=True)
        return weight, bias

    def build_prototypes(self, data_repr, data_label, num_outputs):
        n_dim = data_repr.shape[2]
        data_repr = data_repr.view(-1, n_dim)
        data_label = data_label.view(-1)
        prototypes = torch.zeros((num_outputs, n_dim), device=self.device)
        for c in range(num_outputs):
            idx = torch.nonzero(data_label == c).view(-1)
            if idx.nelement() != 0:
                prototypes[c] = torch.mean(data_repr[idx], dim=0)
        return prototypes

    def initialize_output_layer(self, n_classes):
        self.output_layer_weight = torch.randn((n_classes, 768), requires_grad=True)
        self.output_layer_bias = torch.randn(n_classes, requires_grad=True)

    def train(self, support_text, labels, n_classes, n_iter):
        for itr in range(n_iter):
            print('Iteration ', itr)
            self.learner.zero_grad()
            self.initialize_output_layer(n_classes)

            x = self.learner.encode_text(support_text)
            y = labels.to(device)
            output_repr = self.learner(x)
            init_weights, init_bias = self.initialize_with_proto_weights(output_repr, y, n_classes)

            with higher.innerloop_ctx(self.learner, self.inner_optimizer,
                                      copy_initial_weights=False,
                                      track_higher_grads=True) as (flearner, diffopt):
                for i in range(self.updates):
                    output = flearner(x)
                    output = self.output_layer(output, init_weights, init_bias)
                    output = output.view(output.size()[0] * output.size()[1], -1)
                    loss = self.loss_fn(output, y)

                    output_weight_grad, output_bias_grad = torch.autograd.grad(
                        loss, [self.output_layer_weight, self.output_layer_bias], retain_graph=True)
                    self.output_layer_weight = self.output_layer_weight - self.output_lr * output_weight_grad
                    self.output_layer_bias = self.output_layer_bias - self.output_lr * output_bias_grad

                    diffopt.step(loss)


if __name__ == '__main__':
    encoder = 'bert'  # or 'lstm'
    support_text = [['This is a support text']] * 64
    labels = torch.randint(0, 10, (64 * 30, ))
    n_classes = 10
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = ProtoMAML(device=device, encoder=encoder)
    model.train(support_text, labels, n_classes, n_iter=10)
Okay, I will reopen the issue, but this code is quite complex. Can you please try to reimplement, with as few external dependencies as possible, a minimal example which reproduces the error? Does it really only crop up with huggingface's transformers? It's really time-consuming to dig through a lot of code, understand it, and figure out where the bug is. I'm happy to do that eventually if it's the only way, but I don't have a lot of time right now, so this would really speed things up.
I tried removing the components one by one while still being able to replicate the issue. It's no longer ProtoMAML-specific now; it's just using higher on top of BERT for sequence labelling. If you set [...]. Here is the pared-down code:

import higher
import torch
from torch import nn, optim
from transformers import BertModel, BertTokenizer


class BaseModel(nn.Module):

    def __init__(self, encoder, max_length, device):
        super(BaseModel, self).__init__()
        self.max_length = max_length
        self.device = device

        if encoder == 'bert':
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.encoder = BertModel.from_pretrained('bert-base-uncased')
            self.encoder.pooler.dense.weight.requires_grad = False
            self.encoder.pooler.dense.bias.requires_grad = False
            self.n_tunable_layers = 1
            tunable_layers = {str(l) for l in range(12 - self.n_tunable_layers, 12)}
            for name, param in self.encoder.named_parameters():
                if not set.intersection(set(name.split('.')), tunable_layers):
                    param.requires_grad = False
        elif encoder == 'lstm':
            self.encoder = nn.LSTM(batch_first=True, input_size=32, hidden_size=768)

        self.linear = nn.Linear(768, 10)
        self.to(self.device)

    def encode_text(self, text):
        if isinstance(self.encoder, BertModel):
            encode_result = self.tokenizer.batch_encode_plus(text, return_token_type_ids=False,
                                                             max_length=self.max_length,
                                                             pad_to_max_length=True, return_tensors='pt')
            for key in encode_result:
                encode_result[key] = encode_result[key].to(self.device)
            return encode_result
        elif isinstance(self.encoder, nn.LSTM):
            return torch.randn((len(text), 32, 32), device=self.device)

    def forward(self, inputs):
        if isinstance(self.encoder, BertModel):
            out, _ = self.encoder(inputs['input_ids'], attention_mask=inputs['attention_mask'])
        elif isinstance(self.encoder, nn.LSTM):
            out, _ = self.encoder(inputs)
        out = out[:, 1:-1, :]
        out = self.linear(out)
        return out


class ProtoMAML:

    def __init__(self, device, encoder):
        self.output_layer_weight = None
        self.output_layer_bias = None
        self.learner = BaseModel(encoder=encoder, max_length=32, device=device)
        self.inner_optimizer = optim.SGD([p for p in self.learner.parameters() if p.requires_grad], lr=0.001)
        self.loss_fn = nn.CrossEntropyLoss()
        self.output_lr = 0.001
        self.device = device
        self.updates = 5

    def train(self, support_text, labels, n_iter):
        for itr in range(n_iter):
            print('Iteration ', itr)
            self.learner.zero_grad()

            x = self.learner.encode_text(support_text)
            y = labels.to(device)

            with higher.innerloop_ctx(self.learner, self.inner_optimizer,
                                      copy_initial_weights=False,
                                      track_higher_grads=True) as (flearner, diffopt):
                for i in range(self.updates):
                    output = flearner(x)
                    output = output.view(output.size()[0] * output.size()[1], -1)
                    loss = self.loss_fn(output, y)
                    diffopt.step(loss)


if __name__ == '__main__':
    encoder = 'bert'  # or 'lstm'
    support_text = [['This is a support text']] * 64
    labels = torch.randint(0, 10, (64 * 30, ))
    n_classes = 10
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = ProtoMAML(device=device, encoder=encoder)
    model.train(support_text, labels, n_iter=10)
Hello @Nithin-Holla. Thanks for the pared-down example, although I wouldn't call this minimal. It would have been helpful to get some tracebacks/outputs for when things fail (and what should have happened), and sample output for when things don't fail (e.g. when using the LSTM). It also would have been useful to get a [...]

I've tried to set up an env to repro some error (since there isn't an actual clear statement of what that is) and get the following failure:

Traceback (most recent call last):
File "bert.py", line 88, in <module>
model.train(support_text, labels, n_iter=10)
File "bert.py", line 66, in train
x = self.learner.encode_text(support_text)
File "bert.py", line 32, in encode_text
pad_to_max_length=True, return_tensors='pt')
File "/usr/local/anaconda3/lib/python3.7/site-packages/transformers/tokenization_utils_base.py", line 1832, in batch_encode_plus
**kwargs,
File "/usr/local/anaconda3/lib/python3.7/site-packages/transformers/tokenization_utils.py", line 534, in _batch_encode_plus
ids, pair_ids = ids_or_pair_ids
ValueError: not enough values to unpack (expected 2, got 1)

I'm not sure this corresponds to the sort of error you're getting. Can you please help elucidate this matter?
This error is coming from within the new transformers release, which I didn't have before. Anyway, I bumped up some of the packages I had and the original issue is now gone. I guess it was caused by one of the low-level libraries and not by higher. You can close this for now. Thanks!
I'm trying to implement a first-order version of ProtoMAML (https://arxiv.org/pdf/1903.03096.pdf) for a sequence labelling task. If I use BERT as the encoder, I run into this error at the line with diffopt.step:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Instead, if I use an LSTM as the encoder, it runs successfully. Perhaps the graph is purged somehow since BERT is a large model?

Here is self-contained code to replicate the issue. The issue occurs on both CPU and GPU. On line 117, you can specify the encoder as bert or lstm. It requires the transformers library from HuggingFace to run.
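For reference, this RuntimeError is the generic one PyTorch raises whenever a second backward pass is attempted through a graph whose intermediate buffers were freed by the first pass. A minimal, BERT-free sketch of that behaviour (illustrative only, not the repro above):

import torch

x = torch.randn(3, requires_grad=True)
y = (x * x).sum()

y.backward()          # first backward frees the graph's intermediate buffers
try:
    y.backward()      # second backward through the same graph fails
except RuntimeError as e:
    print(e)          # "Trying to backward through the graph a second time ..."

# Passing retain_graph=True on the first call keeps the graph alive:
z = (x * x).sum()
z.backward(retain_graph=True)
z.backward()          # now succeeds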