In [1]:
import argparse
import os
from tqdm.notebook import tqdm
import numpy as np
import mindspore as ms
from mindspore import ops
from mindnlp.transformers import (
    BertGenerationTokenizer,
    BertGenerationDecoder,
    BertGenerationConfig,
    CLIPModel,
    CLIPTokenizer
)
from loaders.ZO_Clip_loaders import tinyimage_single_isolated_class_loader
from sklearn.metrics import roc_auc_score
from mindspore import context
import sys

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 1.277 seconds.
Prefix dict has been built successfully.


In [2]:
def tokenize_for_clip(batch_sentences, tokenizer):
    """Tokenize a batch of sentences and return only the resulting token ids.

    Args:
        batch_sentences: Sentence(s) handed to ``tokenizer`` unchanged.
        tokenizer: Callable tokenizer (a ``CLIPTokenizer`` at the call site in
            this notebook) whose result exposes an ``input_ids`` attribute.

    Returns:
        The ``input_ids`` attribute of the tokenizer output.
    """
    # Padding and truncation are enabled with a 77-token limit; "ms" is
    # presumably the MindSpore tensor format -- TODO confirm against mindnlp.
    tokenizer_options = dict(
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="ms",
    )
    encoded = tokenizer(batch_sentences, **tokenizer_options)
    return encoded.input_ids

In [3]:
def greedysearch_generation_topk(clip_embed, bert_model, batch_size=32,
                                 top_k=35, seq_len=10, bos_token_id=None):
    """Greedily decode a short token sequence per sample and record top-k candidates.

    At every step the decoder is run on the tokens generated so far, the
    highest-scoring next token is appended to each sample's sequence, and the
    indices of the ``top_k`` highest-scoring tokens at the last position are
    recorded.

    Args:
        clip_embed: Tensor passed to the decoder as ``encoder_hidden_states``;
            its first dimension is the number of samples.
        bert_model: Decoder called with ``input_ids``, ``attention_mask``,
            ``position_ids`` and ``encoder_hidden_states``; its output must
            expose ``logits``.
        batch_size: Unused; kept only so existing callers keep working.
        top_k: Number of candidate token ids recorded per step (default 35).
        seq_len: Sequence length, BOS included, at which decoding stops
            (default 10, i.e. 9 generated tokens). At least one step always
            runs and never more than 77.
        bos_token_id: Id of the start token. When ``None`` it is read from the
            module-level ``berttokenizer``.

    Returns:
        A list with one ``(target_tensor, top_k_tensor)`` tuple per sample.
        ``target_tensor`` is an int64 tensor with the BOS id followed by the
        generated ids; ``top_k_tensor`` is the concatenation of the per-step
        top-k index tensors. An empty batch yields an empty list.
    """
    num_samples = clip_embed.shape[0]
    if num_samples == 0:
        # Nothing to decode; the model would otherwise be fed an empty batch.
        return []

    if bos_token_id is None:
        # Backward-compatible fallback to the notebook-level tokenizer.
        bos_token_id = berttokenizer.bos_token_id

    # Upper bound on decoding steps.
    max_len = 77
    # Every sequence grows by exactly one token per step, so the step count is
    # known in advance: decode until the sequence (BOS included) holds
    # `seq_len` tokens, with at least one step and at most `max_len`.
    num_steps = max(1, min(max_len, seq_len - 1))

    target_lists = [[bos_token_id] for _ in range(num_samples)]
    top_k_lists = [[] for _ in range(num_samples)]
    bert_model.set_train(False)

    for _ in range(num_steps):
        # Re-encode the whole prefix each step (no cache is used here).
        targets = ms.Tensor(target_lists, dtype=ms.int64)
        cur_len = targets.shape[1]
        position_ids = ms.Tensor(np.arange(cur_len)[None].repeat(num_samples, axis=0), ms.int32)
        attention_mask = ops.ones((num_samples, cur_len), dtype=ms.int32)

        out = bert_model(
            input_ids=targets,
            attention_mask=attention_mask,
            position_ids=position_ids,
            encoder_hidden_states=clip_embed,
        )

        # Only the last position is needed, so slice before argmax/top-k
        # instead of processing every position and discarding the rest.
        last_logits = out.logits[:, -1, :]
        # One host transfer per step instead of one `.item()` per sample.
        next_ids = last_logits.argmax(axis=1).astype(ms.int64).asnumpy().tolist()
        _, step_top_k = ops.topk(last_logits, k=top_k, dim=1)

        for j in range(num_samples):
            target_lists[j].append(next_ids[j])
            top_k_lists[j].append(step_top_k[j])

    results = []
    for j in range(num_samples):
        target_tensor = ms.Tensor(target_lists[j], dtype=ms.int64)
        top_k_tensor = ops.concat(top_k_lists[j])
        results.append((target_tensor, top_k_tensor))

    return results


In [4]:
def image_decoder(clip_model, berttokenizer, split, image_loaders=None, bert_model=None,
                  num_seen=20, clip_tokenizer=None):
    """Score every image of a split for out-of-distribution-ness and report AUROC.

    For each image, candidate labels are generated with
    ``greedysearch_generation_topk``, appended to the seen-class descriptions,
    and the image is matched against all descriptions with CLIP. The score of
    an image is the softmax probability mass assigned to the generated
    (non-seen) descriptions.

    Args:
        clip_model: Model providing ``set_train``, ``get_image_features`` and
            ``get_text_features``.
        berttokenizer: Tokenizer whose ``decode`` turns a token id into text.
        split: Sequence of class labels; the first ``num_seen`` are treated as
            seen (target 0), all remaining ones as unseen (target 1).
        image_loaders: Mapping from class label to a loader exposing
            ``create_dict_iterator()`` whose items carry an ``"image"`` entry.
        bert_model: Decoder forwarded to ``greedysearch_generation_topk``.
        num_seen: Number of leading labels of ``split`` considered seen
            (default 20).
        clip_tokenizer: Tokenizer used for the CLIP text descriptions. When
            ``None`` the module-level ``cliptokenizer`` is used.

    Returns:
        The AUROC computed by ``roc_auc_score`` over all scored images.
    """
    if clip_tokenizer is None:
        # Backward-compatible fallback to the notebook-level tokenizer.
        clip_tokenizer = cliptokenizer

    seen_labels = split[:num_seen]
    num_seen_labels = len(seen_labels)
    seen_label_set = set(seen_labels)
    seen_descriptions = [f"This is a photo of a {label}" for label in seen_labels]

    # Ground truth is collected per scored image instead of assuming a fixed
    # 1000 seen / 9000 unseen layout, so it always matches the scores.
    ood_targets = []
    ood_probs_sum = []

    clip_model.set_train(False)

    for class_idx, semantic_label in enumerate(tqdm(split)):
        # Position in the split, not the label text, decides seen vs. unseen.
        is_ood = 0 if class_idx < num_seen_labels else 1
        loader = image_loaders[semantic_label]

        for batch_data in loader.create_dict_iterator():
            batch_images = batch_data["image"]

            # One image forward pass feeds both the decoder conditioning and
            # the zero-shot matching below.
            image_features = clip_model.get_image_features(pixel_values=batch_images)

            # Repeat every feature element twice along axis 1 and insert an
            # axis at position 1 -- presumably to match the decoder's
            # cross-attention input layout; TODO confirm.
            clip_extended_embed = ops.repeat_elements(image_features, rep=2, axis=1)
            clip_extended_embed = ops.expand_dims(clip_extended_embed, 1)

            batch_results = greedysearch_generation_topk(clip_extended_embed, bert_model)
            del clip_extended_embed

            image_features = image_features / ops.norm(image_features, dim=-1, keepdim=True)

            for b_idx, (_, topk_ids) in enumerate(batch_results):
                # Pull the ids to the host once, then decode one id at a time.
                topk_tokens = [berttokenizer.decode(int(token_id)) for token_id in topk_ids.asnumpy()]

                # Sorted so the description order does not depend on set
                # iteration order, which can vary between kernel restarts.
                unique_entities = sorted(set(topk_tokens) - seen_label_set)
                all_desc = seen_descriptions + [f"This is a photo of a {label}" for label in unique_entities]
                all_desc_ids = tokenize_for_clip(all_desc, clip_tokenizer)

                text_features = clip_model.get_text_features(input_ids=all_desc_ids)
                text_features = text_features / ops.norm(text_features, dim=-1, keepdim=True)

                similarity = 100.0 * (image_features[b_idx:b_idx + 1] @ text_features.T)
                # Index the single row rather than squeeze() so a one-element
                # description list still yields a 1-D tensor.
                zeroshot_probs = ops.softmax(similarity, axis=-1)[0]

                # Seen descriptions come first, so everything after them is
                # the probability mass on generated labels.
                ood_prob_sum = float(ops.sum(zeroshot_probs[num_seen_labels:]).asnumpy())
                ood_probs_sum.append(ood_prob_sum)
                ood_targets.append(is_ood)

            del image_features

    auc_sum = roc_auc_score(np.array(ood_targets), np.array(ood_probs_sum))
    print('当前split的sum_ood AUROC={}'.format(auc_sum))
    return auc_sum


In [5]:
def get_args_in_notebook():
    """Build the default arguments for runs that have no command line to parse.

    Returns:
        argparse.Namespace: Namespace whose only attribute, ``trained_path``,
        is ``'./trained_models/COCO/'``.
    """
    notebook_defaults = {'trained_path': './trained_models/COCO/'}
    return argparse.Namespace(**notebook_defaults)

In [6]:
if __name__ == '__main__':
    # A notebook kernel has no command line to parse, so fall back to the
    # built-in defaults there; otherwise read --trained_path from argv.
    if 'ipykernel' in sys.modules or 'IPython' in sys.modules:
        args = get_args_in_notebook()
    else:
        parser = argparse.ArgumentParser()
        parser.add_argument('--trained_path', type=str, default='./trained_models/COCO/')
        args = parser.parse_args()
    # Same device target in both environments, so configure it once.
    context.set_context(device_target="Ascend")

    # os.path.join avoids the doubled separators that string concatenation
    # produced when trained_path already ends with '/'.
    args.saved_model_path = os.path.join(args.trained_path, 'ViT-B32')
    # exist_ok=True replaces the racy "exists, then makedirs" sequence.
    os.makedirs(args.saved_model_path, exist_ok=True)
    decoder_path = os.path.join(args.saved_model_path, 'decoder_model')

    # Initialise the BERT tokenizer. `berttokenizer` and `cliptokenizer` must
    # stay module-level names: the helper functions above read them as globals.
    bert_source = 'google/bert_for_seq_generation_L-24_bbc_encoder'
    berttokenizer = BertGenerationTokenizer.from_pretrained(bert_source)

    # Load the CLIP model and its tokenizer.
    model_name = 'openai/clip-vit-base-patch32'
    try:
        clip_model = CLIPModel.from_pretrained(model_name)
        cliptokenizer = CLIPTokenizer.from_pretrained(model_name)
    except Exception as e:
        # The fallback issues exactly the same calls, so it is a plain retry
        # (useful for transient failures); the message now says so instead of
        # claiming a different download route.
        print(f"Error loading {model_name}, retrying once: {e}")
        clip_model = CLIPModel.from_pretrained(model_name)
        cliptokenizer = CLIPTokenizer.from_pretrained(model_name)

    # Initialise the decoder: prefer a saved checkpoint, otherwise configure
    # the pretrained encoder weights as a decoder with cross-attention.
    if os.path.exists(decoder_path):
        bert_model = BertGenerationDecoder.from_pretrained(decoder_path)
    else:
        # NOTE(review): without a saved decoder the cross-attention layers are
        # presumably not fine-tuned, which would affect the scores -- confirm.
        print(f"Warning: no decoder checkpoint found at {decoder_path}; "
              f"loading {bert_source} configured as a decoder instead.")
        bert_config = BertGenerationConfig.from_pretrained(bert_source)
        bert_config.is_decoder = True
        bert_config.add_cross_attention = True
        bert_config.return_dict = True
        bert_model = BertGenerationDecoder.from_pretrained(bert_source, config=bert_config)

    splits, tinyimg_loaders = tinyimage_single_isolated_class_loader(
        dataset_dir='./data/tiny-imagenet-200/val/',
        labels_to_ids_path='./dataloaders/tinyimagenet_labels_to_ids.txt')

    # Evaluate every split and collect its AUROC.
    sum_scores = []
    for split in splits:
        sum_score = image_decoder(clip_model, berttokenizer, split=split,
                                  image_loaders=tinyimg_loaders, bert_model=bert_model)
        sum_scores.append(sum_score)

    print('5个split的sum auc分数:', sum_scores)
    print('5个split的平均分数:', np.mean(sum_scores), '标准差:', np.std(sum_scores))


BertGenerationDecoder has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`.`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


  0%|          | 0/10 [00:00<?, ?it/s]

当前split的sum_ood AUROC=0.8469840000000002


  0%|          | 0/10 [00:00<?, ?it/s]

当前split的sum_ood AUROC=0.8381759999999999


  0%|          | 0/10 [00:00<?, ?it/s]

当前split的sum_ood AUROC=0.8535600000000001


  0%|          | 0/10 [00:00<?, ?it/s]

当前split的sum_ood AUROC=0.8344800000000001


  0%|          | 0/10 [00:00<?, ?it/s]

当前split的sum_ood AUROC=0.8711199999999999
5个split的sum auc分数: [0.8469840000000002, 0.8381759999999999, 0.8535600000000001, 0.8344800000000001, 0.8711199999999999]
5个split的平均分数: 0.8488640000000001 标准差: 0.012977281317749062
