In [1]:
from transformers import AutoTokenizer, BertModel
from torch import nn
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the GPU when available and load RadBERT onto that same device.
RADBERT_CHECKPOINT = "StanfordAIMI/RadBERT"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(RADBERT_CHECKPOINT)
# Move the encoder to the device chosen above; hard-coding 'cuda' would crash on
# CPU-only machines and could disagree with `device`.
model = BertModel.from_pretrained(RADBERT_CHECKPOINT).to(device)

Some weights of the model checkpoint at StanfordAIMI/RadBERT were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [18]:
class RadBertEmbedder(nn.Module):
    """Embed free-text conditions with RadBERT's [CLS] vector followed by an MLP head.

    The tokenizer and BERT encoder are the notebook-level ``tokenizer`` and ``model``
    objects, bound as class attributes, so every instance shares the same encoder.

    NOTE(review): because ``_model`` is a class attribute (not assigned on ``self``),
    it is not a registered submodule. ``self.parameters()``, ``.to()``, ``.train()``
    and ``.eval()`` on an instance only affect ``self.mlp``, not BERT.

    Args:
        emb_dim: Output dimension of the MLP head.
        *args, **kwargs: Accepted and ignored (kept for call-site compatibility).
    """

    # Tokenized inputs must be on the same device as the BERT weights, so derive the
    # device from the model itself. Probing torch.cuda.current_device() directly
    # raises on CPU-only machines.
    _device = str(next(model.parameters()).device)
    _tokenizer = tokenizer
    _model = model

    def __init__(self, emb_dim=32, *args, **kwargs):
        super().__init__()
        self.emb_dim = emb_dim
        # Width of BERT's output hidden states (768 for RadBERT).
        hidden_size = self._model.config.hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, emb_dim),
            nn.LayerNorm(emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
            nn.LayerNorm(emb_dim),
        ).to(self._device)

    def forward(self, condition):
        """Embed one string or a list of strings.

        Args:
            condition: A string or a list of strings to encode.

        Returns:
            Tensor of shape ``(batch, emb_dim)`` on ``self._device``.
        """
        # padding=True lets a batch mix prompts of different token lengths (without it
        # return_tensors="pt" fails on ragged batches); truncation=True keeps long
        # reports within the model's maximum sequence length.
        inputs = self._tokenizer(
            condition, return_tensors="pt", padding=True, truncation=True
        ).to(self._device)
        # NOTE(review): the encoder runs with autograd enabled, so its activations are
        # kept for backward even though its weights are not in self.parameters().
        # Wrap in torch.no_grad() if BERT is meant to stay frozen.
        outputs = self._model(**inputs)
        # BERT tokenizers pad on the right by default, so position 0 is still [CLS].
        cls_embedding = outputs.last_hidden_state[:, 0]
        return self.mlp(cls_embedding)

In [19]:
# Smoke test: a batch of 64 identical prompts should yield one embedding per prompt.
emb_obj = RadBertEmbedder(emb_dim=1024)
condition = ["A photo of a lung xray with a visible pleural effusion"] * 64
# Call the module itself rather than .forward() so nn.Module hooks are honoured.
c = emb_obj(condition)
assert c.shape == (len(condition), emb_obj.emb_dim)
c.shape

torch.Size([64, 1024])