In [1]:
import io
import os, sys
import requests
import PIL
import math

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from dall_e          import map_pixels, unmap_pixels, load_model
from IPython.display import display, display_markdown

print(torch.__version__)
print(torch.version.cuda)

1.8.0
11.1


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os

import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import skimage

from collections import OrderedDict
import torch

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
device

device(type='cuda')

## visual codebook

## clip image processing and text processing

In [5]:
import clip

clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [6]:
model, preprocess = clip.load("/hy-tmp/clip_model/ViT-B-32.pt", device)
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

Model parameters: 151,277,313
Input resolution: 224
Context length: 77
Vocab size: 49408


In [7]:
preprocess

Compose(
    Resize(size=224, interpolation=bicubic)
    CenterCrop(size=(224, 224))
    <function _convert_image_to_rgb at 0x7f6dbb4829d0>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

In [8]:
expert_knowledge = ["The number of islands on an aircraft carrier is 1, the bow shape is blunt, and the hull has a flat runway.", 
"The number of islands on a destroyer is 1, the bow shape is pointed, and the hull has a fluctuating island.", 
"The number of islands on a cruiser is 2, the bow shape is pointed, and the hull has a fluctuating island.", 
"The number of islands on a supply ship is 2, the bow shape is pointed, and the hull has a fluctuating gantry and a fluctuating island.", 
"A cruise ship has no islands, the bow shape is pointed, and the hull has cabins."]

注释：有字符的地方会编码为向量，没有字符的地方就是0，维度固定为77是因为给句子规定了一个最大长度77。另外注意：句首和句末分别有一个起始符和一个终止符。

CoOp原文中使用的token数量是16。

## self-attention

In [9]:
def KnowledgeTransformer(vit, x, contexts):
    x = vit.conv1(x)  # shape = [*, width, grid, grid]
    x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
    x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
    x = torch.cat([vit.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
    x = x + vit.positional_embedding.to(x.dtype)
    x = vit.ln_pre(x)
    # print(x.size())
    # print(contexts.size())
    
    contexts = contexts.repeat(x.size()[0], 1, 1).to(device)
    # print(contexts.size())
    x = torch.cat((x, contexts), 1).half()
    # print(x.size())
    x = x.permute(1, 0, 2)  # NLD -> LND
    x = vit.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD

    x = vit.ln_post(x[:, 0, :])

    if vit.proj is not None:
        x = x @ vit.proj
    return x


class SelfAttention(nn.Module):
    def __init__(self, num_attention_heads, input_size_q, input_size_kv, hidden_size, hidden_dropout_prob):
        super(SelfAttention, self).__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = hidden_size

        self.query = nn.Linear(input_size_q, self.all_head_size)
        self.key = nn.Linear(input_size_kv, self.all_head_size)
        self.value = nn.Linear(input_size_kv, self.all_head_size)

        self.attn_dropout = nn.Dropout(hidden_dropout_prob)

        # 做完self-attention 做一个前馈全连接 LayerNorm 输出
        # self.dense = nn.Linear(hidden_size, hidden_size)
        # self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
        # self.out_dropout = nn.Dropout(hidden_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        # print(new_x_shape)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, input_q, input_k, input_v):
        query_layer = self.transpose_for_scores(self.query(input_q))
        key_layer = self.transpose_for_scores(self.key(input_k))
        value_layer = self.transpose_for_scores(self.value(input_v))
        
        # Cross-attention
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        attention_probs = self.attn_dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context = context_layer.view(*new_context_layer_shape)
        # hidden_states = self.dense(context_layer)
        # hidden_states = self.out_dropout(hidden_states)
        # hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return context

    
class LinearClassifier(nn.Module): 
    def __init__(self, input_dim, output_dim): 
        super(LinearClassifier, self).__init__() 
        self.fc = nn.Linear(input_dim, output_dim)
        
    def forward(self, x): 
        x = self.fc(x)
        # print(x.size())
        return F.log_softmax(x, dim=1)
        
        
class ContrastiveClassifier(nn.Module):
    def __init__(self):
        super(ContrastiveClassifier, self).__init__()
        self.t = 0.07
        
    def forward(self, image_feature, text_feature):
        # image_feature = torch.norm(image_feature, dim=-1)  # 32, 768
        # text_feature = torch.norm(text_feature, dim=-1).squeeze(0)    #[5,768]->
        # logits = (image_feature @ text_feature.T) * torch.exp(self.t) #32，5
        
        # 这样写似乎容易出问题
        # image_feature /= image_feature.norm(dim=-1, keepdim=True)
        # text_feature /= text_feature.norm(dim=-1, keepdim=True)
        
        image_feature = F.normalize(image_feature, p=2, dim=-1)
        text_feature = F.normalize(text_feature, p=2, dim=-1)

        logits = (100.0 * image_feature @ text_feature.T).softmax(dim=-1)
        return logits
        
        
class TextEncoder(nn.Module):
    def __init__(self, n_context):
        super().__init__()
        self.transformer = model.transformer
        self.positional_embedding = model.positional_embedding
        self.ln_final = model.ln_final
        self.text_projection = model.text_projection
        self.dtype = model.dtype
        self.n_context = n_context

    def forward(self, prompts, tokenized_prompts, context_feature):
        x = prompts + self.positional_embedding
        
        x = x.permute(1, 0, 2).to(torch.float16)  # NLD -> LND
        # print(x.dtype)
        x = self.transformer(x)
        
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        
        if context_feature:
            x = x[:, 1: self.n_context + 1, :]  # x.shape is [5, 16, 512]
            tokenized_prompts = tokenized_prompts[:, 1: self.n_context]
        
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection

        return x
        
class PromptLearner(nn.Module):
    def __init__(self, expert_knowledge):
        super().__init__()
        n_context = 32  # 与coop一致
        context_dim = model.ln_final.weight.shape[0]  # 512
        len_knowledge = len(expert_knowledge)
        
        context_vectors = torch.empty(1, n_context, context_dim)
        context_vectors = context_vectors.repeat(len_knowledge, 1, 1).to(device)
        nn.init.normal_(context_vectors, std=0.02)
        
        self.context = nn.Parameter(context_vectors)
        
        prompt_prefix = " ".join(["X"] * n_context)
        prompts = [prompt_prefix + " " + kl for kl in expert_knowledge]
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
        with torch.no_grad():
            embedding = model.token_embedding(tokenized_prompts)
        
        self.register_buffer("token_prefix", embedding[:, :1, :])  # Start Of the Sentence
        self.register_buffer("token_suffix", embedding[:, 1 + n_context :, :])  # Expert Knowledge, End of Sentence
        
        self.n_context = n_context
        self.tokenized_prompts = tokenized_prompts
        
    def forward(self):
        context = self.context
        
        prefix = self.token_prefix
        suffix = self.token_suffix
        prompts = torch.cat([prefix, context, suffix], dim=1)
        
        return prompts, context
        

class MyKnowledgeNet(nn.Module):
    def __init__(self):
        super(MyKnowledgeNet, self).__init__()
        
        enc = load_model("/hy-tmp/vae/encoder.pkl", device)
        dec = load_model("/hy-tmp/vae/decoder.pkl", device)
        params = enc.state_dict()  # 提取出的visual codebook的参数
        self.vc_weight = params["blocks.output.conv.w"]
        self.vc_weight = self.vc_weight.squeeze(2).squeeze(2).unsqueeze(0)
        
        num_attention_heads = 8
        input_size_q = 512
        input_size_kv = 2048
        hidden_size = 768
        hidden_dropout_prob = 0.1
        self.self_attention = SelfAttention(num_attention_heads, input_size_q, input_size_kv, hidden_size, hidden_dropout_prob).to(device)
        
        input_dim = 512
        output_dim = 5
        self.linear = LinearClassifier(input_dim, output_dim).to(device)
        
        self.contrastive = ContrastiveClassifier().to(device)
        
        self.prompt_learner = PromptLearner(expert_knowledge).to(device)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.n_context = self.prompt_learner.n_context
        
        self.text_encoder = TextEncoder(self.n_context)
        
        self.vt = model.visual.to(device)
        
        self.embedding_projection = nn.Linear(512, 768)  # 512是文本向量的维度，768是图像向量的维度，学习这样一个投影层

    def forward(self, images):
        
        prompts, context = self.prompt_learner()  # context.size()=[5,16,512]
        text_feature = self.text_encoder(prompts, self.tokenized_prompts, context_feature=False)
        text_feature = text_feature.to(torch.float32)
        
        context = context.mean(dim=0, keepdim=True).to(torch.float32)
        context = context[:, :16, :]
        k = self.self_attention(context, self.vc_weight, self.vc_weight)
        
        image_feature = KnowledgeTransformer(self.vt, images, k).float()
        CLS_vector = self.contrastive(image_feature, text_feature)
        return CLS_vector


一个问题：如何让context的维度能够使用？是否应该重视每个单词的意义？我想肯定要重视。mean的方法可靠性存疑。

In [10]:
from torch.utils.data import DataLoader

class ImageDataset(torch.utils.data.Dataset):
    
    def __init__(self, img_dir):
        self.imgs = []
        self.labels = []
        label_dict = {'Aircraft_Carrier': 0, 
                      'Amphibious_Assault_Ship': 1, 
                      'Fast_Combat_Support_Ships': 2,
                      'Guided_Missile_Cruiser': 3,
                      'Guided_Missile_Destroyer': 4}
        for label in os.listdir(img_dir):
            dir_path = os.path.join(img_dir, label)
            for img in os.listdir(dir_path):
                img_path = os.path.join(dir_path, img)
                img_label = label
                self.imgs.append(img_path)
                self.labels.append(label_dict[img_label])
        
    def __len__(self):
        return len(self.imgs)
    
    def __getitem__(self, index):
        img = self.imgs[index]
        label = self.labels[index]
        return img, label

In [11]:
img_ships = ImageDataset("/hy-tmp/5_types_ships_small/train")
test_ships = ImageDataset("/hy-tmp/5_types_ships_small/test")
train_loader = torch.utils.data.DataLoader(img_ships, batch_size=96, shuffle=True, drop_last=False, num_workers=32)
test_loader = torch.utils.data.DataLoader(test_ships, batch_size=96, shuffle=False, drop_last=False, num_workers=32)

In [12]:
mymodel = MyKnowledgeNet().to(device)

In [13]:
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()

# backbone_params = list(map(id, mymodel.vt.parameters()))
# align_parmas = filter(lambda p: id(p) not in backbone_params, mymodel.parameters())

optimizer = optim.Adam([{'params': mymodel.self_attention.parameters()},
                       # {'params': mymodel.linear.parameters()}, 
                       # {'params': mymodel.contrastive.parameters()}, 
                       {'params': mymodel.embedding_projection.parameters()},
                       {'params': mymodel.prompt_learner.parameters()}], lr=1e-3)
# optimizer = optimizer.to(device)

In [14]:
# 训练模型
for epoch in range(30):
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        labels = labels.to(device)
        images = [preprocess(PIL.Image.open(image)) for image in inputs]
        image_input = torch.tensor(np.stack(images)).half().to(device)

#         output = mymodel(image_input)
#         loss = criterion(output, labels)
        
#         optimizer.zero_grad()
#         loss.backward(retain_graph=True)
#         optimizer.step()

        optimizer.zero_grad()  # move zero_grad before the forward pass
        output = mymodel(image_input)
        loss = criterion(output, labels)
        
        loss.backward(retain_graph=True)
        optimizer.step()
        
        if i % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, i * len(data), len(train_loader.dataset), 100. * i / len(train_loader), loss.item()))
            
            
print('Finished Training')

Finished Training


In [15]:
# vt = model.visual
# y = KnowledgeTransformer(vt, image_input, x)
# print(y)
# print(y.size())

In [16]:
torch.save(mymodel.state_dict(), '/hy-tmp/model/model_state_dict_v2_fixed_conloss_32.ptl')

In [17]:
torch.save(mymodel, '/hy-tmp/model/test_model_v2_fixed_conloss_32.ptl')

In [18]:
mymodel.load_state_dict(torch.load('/hy-tmp/model/model_state_dict_v2_fixed_conloss_32.ptl'))

<All keys matched successfully>

In [19]:
top_1_correct = 0
top_2_correct = 0
top_3_correct = 0
total = 0
# set the model to evaluation mode
mymodel.eval()

with torch.no_grad():
    for data in test_loader:
        inputs, labels = data
        labels = labels.to(device)
        images = [preprocess(PIL.Image.open(image)) for image in inputs]
        image_input = torch.tensor(np.stack(images)).half().to(device)
        
        outputs = mymodel(image_input)
        _, predicted = torch.topk(outputs.data, k=3, dim=1)
        total += labels.size(0)
        top_1_correct += (predicted[:, 0] == labels).sum().item()
        top_2_correct += ((predicted[:, 0] == labels) | (predicted[:, 1] == labels)).sum().item()
        top_3_correct += ((predicted[:, 0] == labels) | (predicted[:, 1] == labels) | (predicted[:, 2] == labels)).sum().item()

print('Top-1 accuracy of the network on the %d test images: %.2f %%' % (len(test_loader.dataset), 100 * top_1_correct / total))
print('Top-2 accuracy of the network on the %d test images: %.2f %%' % (len(test_loader.dataset), 100 * top_2_correct / total))
print('Top-3 accuracy of the network on the %d test images: %.2f %%' % (len(test_loader.dataset), 100 * top_3_correct / total))

Top-1 accuracy of the network on the 449 test images: 77.28 %
Top-2 accuracy of the network on the 449 test images: 88.64 %
Top-3 accuracy of the network on the 449 test images: 94.65 %
