Hi! For a project I am working on, I created a PyTorch LightningModule that contains a CLIP model as well as a linear layer and fine-tuned it on my own data. This may not be the most ideal setup, but I was nonetheless able to get accuracies in the high 90s when first doing this. However, after saving a checkpoint with trainer.save_checkpoint(...), loading it back in and testing on the same data only gets chance-level accuracy, as if the model had never been trained at all. This wrapper approach has worked for me in other similar setups, like using torchvision ResNet models inside a LightningModule with an extra linear layer attached, so I suspect it could be a CLIP-specific issue with how I'm handling things.
Here is an abbreviated version of what my LightningModule looks like:
```python
import pytorch_lightning as pl
import torch.nn as nn


class ClipWrapper(pl.LightningModule):
    def __init__(self, model, preprocess, num_features, num_classes, lr=1e-8):
        super(ClipWrapper, self).__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.feature_size = num_features
        self.preprocess = preprocess
        self.clipmodel = model
        self.ff_head = nn.Linear(self.feature_size, num_classes)
        ...

    def forward(self, x):
        encodings = self.clipmodel.encode_image(x)
        logits = self.ff_head(encodings.float())
        return logits
```

When I look at the state_dict saved by PyTorch Lightning (i.e. when I print out ckpt['state_dict'].keys()), I see the parameters that start with clipmodel. as well as the parameters for my extra linear layer, ff_head.weight and ff_head.bias.
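This is roughly the check I ran (a quick sketch, not my exact code; it assumes model is a ClipWrapper instance built the same way as in method 2 below):

```python
import torch as th

ckpt = th.load("saved.ckpt", map_location="cpu")
ckpt_keys = set(ckpt["state_dict"].keys())
model_keys = set(model.state_dict().keys())

# both clipmodel.* and ff_head.* entries are present, and nothing is missing
print("ff_head params in checkpoint:", sorted(k for k in ckpt_keys if k.startswith("ff_head.")))
print("missing from checkpoint:", model_keys - ckpt_keys)
print("unexpected in checkpoint:", ckpt_keys - model_keys)
```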
However, neither of the following methods of loading the checkpoint works; both run just fine, but afterwards the model still only gets chance-level accuracy.

```python
## method 1
newmodel = ClipWrapper.load_from_checkpoint("saved.ckpt")
newmodel.eval()

## method 2
# this is the same code I used to initialize the model when it was trained:
clipmodel, preprocess = clip.load("ViT-B/32")
clip.model.convert_weights(clipmodel)
model = clipwrapper.ClipWrapper(clipmodel, preprocess, n_features, n_classes, lr=learning_rate)

# then I load in the checkpoint and just update the state_dict
ckpt = th.load("saved.ckpt")
model.load_state_dict(ckpt['state_dict'])
model.eval()
```

I was in an interactive session and was able to figure out that, while some of the states were changing after loading the checkpoint, a subset of the clipmodel.visual.transformer.resblocks parameters were taking on new values that weren't in the initial model or the checkpoint (unless I'm mistaken; I might be confusing myself at this point). Is there something about state_dicts with CLIP that I'm missing, like parts of the model being reinstantiated when the model is loaded? It could be that I was saving my checkpoint wrong, but I honestly don't know how else to do it other than using torch.save to directly save model.state_dict(), which I feel would net the same results, right?
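For what it's worth, this is roughly how I was comparing the parameters in the interactive session (again just a sketch using the model from method 2 above, not my exact code):

```python
import torch as th

# snapshot the freshly initialized parameters before loading the checkpoint
before = {k: v.detach().clone() for k, v in model.state_dict().items()}

ckpt = th.load("saved.ckpt", map_location="cpu")
model.load_state_dict(ckpt["state_dict"])
after = model.state_dict()

for name, tensor in after.items():
    # cast to float32 on CPU so fp16 (from convert_weights) vs fp32 doesn't matter
    new = tensor.float().cpu()
    changed = not th.equal(before[name].float().cpu(), new)
    matches_ckpt = name in ckpt["state_dict"] and th.equal(new, ckpt["state_dict"][name].float().cpu())
    if changed and not matches_ckpt:
        # value is neither the initial one nor the checkpointed one --
        # this is where the clipmodel.visual.transformer.resblocks entries showed up
        print(name)
```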
Apologies for the stream-of-consciousness type of post, but let me know if there's any more information I can provide. It's probably just an issue on my end as I'm quite new to training large models, but I just want to check here in case anybody else has had a similar issue.