<a href="https://colab.research.google.com/github/d98s/TrialNotebook/blob/main/Untitled11.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn

In [None]:
import json
# data_dir = '../Downloads/annotations_trainval2014'
data_dir = './'
data_type = 'val2014'
anno = '{}/annotations/instances_{}.json'.format(data_dir, data_type)
# NOTE(review): `COCO` (presumably pycocotools.coco.COCO) is not imported in any
# visible cell — confirm it is imported before this cell runs.
coco = COCO(anno)
categories = coco.loadCats(coco.getCatIds())
category_names = [cat['name'] for cat in categories]
# Raw annotation JSON, so later cells can iterate annotations['annotations'].
# A context manager closes the file handle (the original left it open).
with open(anno, encoding='utf-8') as f:
    annotations = json.load(f)

In [None]:
# Group each annotation's image id and bbox by category id (parallel lists:
# cat_id[c][k] and cat_box[c][k] describe the same annotation).
cat_id = {}
cat_box = {}
for annot in annotations['annotations']:
    cid = annot['category_id']
    cat_id.setdefault(cid, []).append(annot['image_id'])
    cat_box.setdefault(cid, []).append(annot['bbox'])

# Categories sampled for prompt training; every sampled index contributes one
# crop per category, in this order.
target_cats = (1, 2, 3)
num_samples = 1000
# Indices are drawn below 2000 as before, but never past the shortest category
# list, which would otherwise raise IndexError.
sample_bound = min([2000] + [len(cat_id[c]) for c in target_cats])
indices = torch.randint(sample_bound, (num_samples,))

arranged_id = []
arranged_box = []
for idx in indices.tolist():
    for cid in target_cats:
        arranged_id.append(cat_id[cid][idx])
        arranged_box.append(cat_box[cid][idx])

In [None]:
class CocoDataset():
    """Map-style dataset yielding preprocessed crops of COCO bounding boxes.

    Item ``i`` is the region ``bboxes[i]`` (``[x, y, w, h]``, truncated to ints)
    of the image with id ``img_ids[i]``, read from ``img_path`` and passed
    through ``transform``.

    Args:
        img_ids: sequence of COCO image ids.
        bboxes: sequence of boxes, parallel to ``img_ids``.
        img_path: directory holding the image files.
        transform: callable applied to the cropped PIL image; defaults to the
            notebook-global ``processor`` when ``None``.
        file_prefix: file-name prefix placed before the 12-digit zero-padded id.
    """

    def __init__(self, img_ids, bboxes, img_path, transform=None,
                 file_prefix='COCO_val2014_'):
        self.img_ids = img_ids
        self.bboxes = bboxes
        self.img_path = img_path
        self.transform = transform
        self.file_prefix = file_prefix
        self.len = len(self.img_ids)

    def __len__(self):
        return self.len

    def _image_file(self, img_id):
        """Return the path of the image file for ``img_id`` (id zero-padded to 12 digits)."""
        import os
        return os.path.join(self.img_path, self.file_prefix + str(img_id).zfill(12) + '.jpg')

    def __getitem__(self, idx):
        # The context manager releases the file handle once pixels are copied out.
        with Image.open(self._image_file(self.img_ids[idx])) as pil_img:
            img = np.asarray(pil_img)
        x, y, w, h = (int(v) for v in self.bboxes[idx][:4])
        # Sub-pixel boxes truncate to 0 width/height; widen them so the crop is non-empty.
        w = max(w, 1)
        h = max(h, 1)
        transform = processor if self.transform is None else self.transform
        return transform(Image.fromarray(img[y:y + h, x:x + w]))

In [None]:
class TextEncoder(nn.Module):
  """Run CLIP's text transformer on pre-computed token embeddings.

  Mirrors CLIP's text encoding but starts from embeddings (so learnable prompt
  context can be injected) instead of token ids.
  """
  def __init__(self, clip_model):
    super(TextEncoder, self).__init__()
    self.positional_embedding = clip_model.positional_embedding
    self.transformer = clip_model.transformer
    self.ln_final = clip_model.ln_final
    self.text_projection = clip_model.text_projection
    self.dtype = clip_model.dtype

  def forward(self, prompt_embeddings, prompt_tokens):
    """Encode prompts.

    Args:
      prompt_embeddings: embeddings indexed as (prompt, position, channel) —
        assumed to match ``positional_embedding`` in length and width.
      prompt_tokens: token ids for the same prompts; the position of each row's
        maximum id selects the feature that is projected.

    Returns:
      One projected feature vector per prompt.
    """
    x = prompt_embeddings + self.positional_embedding.type(self.dtype)
    x = x.permute(1, 0, 2)    # NLD -> LND
    x = self.transformer(x)
    x = x.permute(1, 0, 2)    # LND -> NLD
    x = self.ln_final(x).type(self.dtype)
    # In CLIP's vocabulary the end-of-text token has the largest id, so argmax
    # locates it; its feature represents the whole prompt.
    eot = prompt_tokens.argmax(dim=-1).to(x.device)
    rows = torch.arange(x.shape[0], device=x.device)
    return x[rows, eot] @ self.text_projection.type(self.dtype)

In [None]:
class Prompt(nn.Module):
  """Learnable context vectors shared across classes, CoOp-style.

  Each class prompt is tokenised as "X X ... X <class name>." and the embeddings
  of the ``n_ctx`` placeholder positions are replaced by the learnable ``ctx``.
  NOTE(review): this assumes "X" tokenises to exactly one token and that
  position 0 holds the start token — confirm for the tokenizer in use.
  """
  def __init__(self, clip_model, n_ctx, class_names):
    super(Prompt, self).__init__()
    self.n_ctx = n_ctx
    self.class_names = class_names
    self.n_class = len(self.class_names)
    self.dtype = clip_model.dtype
    ctx_dim = clip_model.ln_final.weight.shape[0]
    ctx_vec = torch.empty(self.n_ctx, ctx_dim, dtype=self.dtype)
    nn.init.normal_(ctx_vec, std=0.02)
    self.ctx = nn.Parameter(ctx_vec)
    prompt_prefix = " ".join(["X"] * self.n_ctx)
    prompts = [prompt_prefix + " " + name + "." for name in self.class_names]
    # Tokens must live on the same device as the embedding table they index.
    embed_device = clip_model.token_embedding.weight.device
    prompt_tokens = torch.cat([clip.tokenize(p) for p in prompts]).to(embed_device)
    with torch.no_grad():
      embedding = clip_model.token_embedding(prompt_tokens).type(self.dtype)
    # Registered as buffers so Module.to(device) moves them along with ctx.
    self.register_buffer('prompt_tokens', prompt_tokens)
    self.register_buffer('prefix_embedding', embedding[:, :1, :])
    self.register_buffer('suffix_embedding', embedding[:, n_ctx + 1:, :])

  def forward(self):
    """Return prompt embeddings: per class, prefix + learnable ctx + suffix."""
    # Share the same context vectors across every class prompt.
    ctx = self.ctx.unsqueeze(0).expand(self.n_class, -1, -1)
    return torch.cat([self.prefix_embedding, ctx, self.suffix_embedding], dim=1)

In [None]:
class CustomCLIP(nn.Module):
  """CLIP image/text similarity with a learnable text prompt.

  ``forward`` returns the cosine similarity between every image and every
  class prompt (images along dim 0, classes along dim 1).
  """
  def __init__(self, clip_model, n_ctx, class_names):
    # Required before assigning submodules, otherwise nn.Module raises.
    super(CustomCLIP, self).__init__()
    self.prompt = Prompt(clip_model, n_ctx, class_names)
    self.prompt_tokens = self.prompt.prompt_tokens
    self.text_encoder = TextEncoder(clip_model)
    self.image_encoder = clip_model.visual
    self.dtype = clip_model.dtype

  def forward(self, images):
    prompt_embeddings = self.prompt()
    # Read the tokens from the Prompt module at call time rather than the
    # reference cached in __init__.
    prompt_tokens = self.prompt.prompt_tokens
    text_encodings = self.text_encoder(prompt_embeddings, prompt_tokens)
    # Out-of-place normalisation: an in-place `/=` would overwrite a tensor
    # autograd saved for norm's backward and make loss.backward() fail.
    text_encodings = text_encodings / text_encodings.norm(dim=-1, keepdim=True)
    image_encodings = self.image_encoder(images.type(self.dtype))
    image_encodings = image_encodings / image_encodings.norm(dim=-1, keepdim=True)
    return image_encodings @ text_encodings.T

In [None]:
def train(custom_clip, dataloader, optim, criterion, num_epochs, device, logit_scale=100.0):
  """Optimise ``custom_clip`` (in practice its prompt context) for ``num_epochs``.

  Each batch is either a tensor of preprocessed images or an ``(images, labels)``
  pair. Without labels, image i of a batch is taken to be class i. This matches
  batches of the (category 1, 2, 3) triples built earlier in the notebook, loaded
  with batch_size=3 and shuffle=False.
  NOTE(review): that assumes category ids 1, 2, 3 sit at positions 0, 1, 2 of
  the class list given to CustomCLIP — confirm.

  Args:
    custom_clip: model mapping images to image-by-class similarities.
    dataloader: iterable of batches as described above.
    optim: optimizer over the trainable parameters.
    criterion: classification loss taking (logits, class indices).
    num_epochs: number of passes over ``dataloader``.
    device: device the batches are moved to.
    logit_scale: multiplier applied to the cosine similarities before the loss.
      Values in [-1, 1] give a nearly uniform softmax. CLIP's released models
      use ~100 — compare with ``clip_model.logit_scale.exp()``.

  Returns:
    List with the mean per-sample loss of each epoch.
  """
  custom_clip.train()
  epoch_losses = []
  for epoch in range(num_epochs):
    print('epoch ', epoch)
    running_loss = 0.0
    n_seen = 0
    for inputs in tqdm(dataloader):
      if isinstance(inputs, (list, tuple)):
        images, targets = inputs
      else:
        images = inputs
        targets = torch.arange(images.shape[0])
      images = images.to(device)
      targets = targets.to(device)
      optim.zero_grad()
      sim = custom_clip(images)
      # The loss needs differentiable scores, not argmax indices; compute it in
      # float32 for numerical stability.
      logits = logit_scale * sim.float()
      loss = criterion(logits, targets)
      loss.backward()
      optim.step()
      running_loss += loss.item() * images.shape[0]
      n_seen += images.shape[0]
    epoch_loss = running_loss / max(n_seen, 1)
    epoch_losses.append(epoch_loss)
    print('loss: ', epoch_loss)
  return epoch_losses

In [None]:
def freeze_parameters(model):
    """Turn off gradient tracking for every parameter of ``model`` (in place)."""
    for weight in model.parameters():
        weight.requires_grad_(False)

In [None]:
# Use the GPU when one is visible; the CLIP model and its preprocessing
# transform are loaded for that device.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
clip_model, processor = clip.load('ViT-B/32', device=device)

In [None]:
# arranged_id/arranged_box hold consecutive triples (category ids 1, 2, 3), so
# unshuffled batches of three keep each triple together in a fixed order.
dataset = CocoDataset(img_ids=arranged_id, bboxes=arranged_box, img_path='./val2014/')
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, batch_size=3)

In [None]:
n_ctx = 5
# Only the prompt context is meant to learn; freeze CLIP's own weights.
freeze_parameters(clip_model)
custom_clip = CustomCLIP(clip_model, n_ctx, category_names).to(device)
# Optimizers take an iterable of parameters, not a bare tensor.
optimizer = torch.optim.Adam([custom_clip.prompt.ctx], lr=1e-5)
criterion = nn.CrossEntropyLoss()
# Train the prompt-augmented model, not the raw CLIP model.
train(custom_clip, dataloader, optimizer, criterion, 200, device)