In [6]:
from distutils.command.register import register
import os
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split
# import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
# from .Adpater.models import register
from extractor import addImagePath, textExtraction, imageExtraction, textExtractReverse
#
from lavis.models.blip_models.blip import BlipBase
from lavis.models.blip_models.blip_outputs import (
    BlipOutput,
    BlipIntermediateOutput,
)
from Main.Adpater.med import XBertLMHeadDecoder
from Main.Adpater.vit import VisionTransformerEncoder, Block
# # from omegaconf import OmegaConf

# 1. Load the data and split the data

In [16]:
class OxfordDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(text, image, funny_score)`` triples.

    The three constructor arguments are index-aligned sequences; the dataset
    length is taken from ``text``.
    """

    def __init__(self, text, image, funny_score):
        self.text, self.image, self.funny_score = text, image, funny_score

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        # Feature extraction is delegated to the project's ``extractor`` helpers;
        # what they return is not visible from this notebook.
        caption_features = textExtraction(self.text[idx])
        image_features = imageExtraction(self.image[idx])
        # One-element float tensor holding the label (shape [1]).
        score = torch.tensor([float(self.funny_score[idx])])
        return caption_features, image_features, score

In [18]:
# Dataset configuration. Everything a reader might tune lives here instead of
# being repeated inline below.
# Alternative dataset (Oxford HIC), kept for reference:
#     dirPath = '../Data/Oxford_HIC/oxford_hic_data.csv'
#     imgPath = '../Data/Oxford_HIC/oxford_img/'
ACCOUNT_NAME = 'wendys'   # Instagram account whose filtered CSV / image folder is used
TEST_FRACTION = 0.2       # share of rows held out for evaluation
SPLIT_SEED = 42           # fixed seed so the train/test split is reproducible
BATCH_SIZE = 32

dirPath = '../Data/Instagram/Filter_' + ACCOUNT_NAME + '.csv'
imgPath = '../Data/Instagram/' + ACCOUNT_NAME + '_img/'
# load data
data = pd.read_csv(dirPath)
data = addImagePath(data, imgPath)
# split data
train, test = train_test_split(data, test_size=TEST_FRACTION, random_state=SPLIT_SEED)

# Pull the three columns the dataset needs out of each split as plain lists.
train_text = train['caption'].tolist()
train_image = train['image_id'].tolist()
train_funny_score = train['funny_score'].tolist()
test_text = test['caption'].tolist()
test_image = test['image_id'].tolist()
test_funny_score = test['funny_score'].tolist()

train_dataset = OxfordDataset(train_text, train_image, train_funny_score)
test_dataset = OxfordDataset(test_text, test_image, test_funny_score)
# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Sanity check: pull a single batch from the training loader and print the
# shape of each of its three elements.
# NOTE(review): `.shape` assumes every batch element is a tensor/array — confirm
# what textExtraction / imageExtraction return.
# NOTE: this rebinds the notebook-level names `text` and `funny_score`.
dataL = iter(train_loader)
text, imgs, funny_score = next(dataL)
print("shape of text: ", text.shape)
print("shape of image: ", imgs.shape)
print("shape of funny_score: ", funny_score.shape)

In [8]:
### Official Gemma checkpoint ###########################################################################
# One constant for the checkpoint id so the tokenizer, weights and config cannot drift apart.
# NOTE(review): the recorded output of this cell is a ValueError about GemmaTokenizer;
# presumably the installed `transformers` predates Gemma support — verify the version.
GEMMA_CHECKPOINT = "google/gemma-2-2b"
tokenizer = AutoTokenizer.from_pretrained(GEMMA_CHECKPOINT)
gemma = AutoModelForCausalLM.from_pretrained(GEMMA_CHECKPOINT, device_map="auto", torch_dtype=torch.bfloat16)
gemmaConfig = AutoConfig.from_pretrained(GEMMA_CHECKPOINT)
########################################################################################################

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


ValueError: Tokenizer class GemmaTokenizer does not exist or is not currently imported.

In [None]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# NOTE(review): the only `register` imported in this notebook comes from
# `distutils.command.register`, which is a distutils command class rather than a
# model-registry decorator; the commented-out `from .Adpater.models import register`
# looks like the intended source — confirm before running this cell.
@register('blip_caption')
class BlipCaption(BlipBase):
    """
    BLIP captioning model for Screen2words.

    Supported model types:
        - base_coco: fine-tuned BLIP base model on COCO caption dataset (Karparthy split).

    Pretrained weight url :
        - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP/blip_coco_caption_base.pth

    On top of the BLIP vision encoder / text decoder this class adds:
        - an optional vision-language bridge (``VLBridge``),
        - a linear + LayerNorm "feed forward" fusion step,
        - the notebook-level ``gemma`` model, split into everything but its last
          child (``self.gemma``) and everything but its first child (``self.gemmaLm_head``),
        - a two-layer linear head that regresses a funny score.
    """

    def __init__(self,vit_type='base', med_config_path='./Adapter/configs/med_config.json',adapter_type=None, bert_adapter=None, visual_projection=None, tune_language=None, prompt=None, max_txt_len=40):
        """
        Args:
            vit_type: forwarded to ``VisionTransformerEncoder.from_config``.
            med_config_path: decoder config path forwarded to ``XBertLMHeadDecoder.from_config``.
                NOTE(review): the default says ``Adapter`` while the imports use
                ``Main.Adpater`` — confirm the path exists.
            adapter_type: forwarded to the vision encoder; also stored on the instance.
            bert_adapter: stored on the instance only.
            visual_projection: ``"linear"`` or ``"ViT_block"`` selects the bridge applied
                to the image embeddings; any other value disables it.
            tune_language: stored on the instance only.
            prompt: optional text prompt; its token count (minus one) is masked out of
                the decoder targets. ``None`` means no prompt.
            max_txt_len: stored on the instance only.
        """
        super().__init__()
        # `self.tokenizer` is read below and in forward_decoder()/generate(), so it
        # must exist; create it unless the base class already provided one.
        if not hasattr(self, "tokenizer"):
            self.tokenizer = self.init_tokenizer()
        # vision encoder
        self.visual_encoder = VisionTransformerEncoder.from_config(vit_type = vit_type, adapter_type=adapter_type)
        # text encoder + multimodal decoder
        self.text_decoder = XBertLMHeadDecoder.from_config(med_config_path, False)

        self.prompt = prompt
        # Tokenizing ``None`` would fail, and "no prompt" means nothing to mask.
        if prompt is None:
            self.prompt_length = 0
        else:
            self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1

        self.max_txt_len = max_txt_len
        self.adapter_type = adapter_type
        self.bert_adapter = bert_adapter
        self.tune_language = tune_language
        # if((self.bert_adapter or self.tune_language) and self.adapter_type == None):
        if visual_projection == "linear":
            self.VLBridge = nn.Linear(768, 768)
        elif visual_projection == "ViT_block":
            self.VLBridge = Block(dim=768,num_heads=12,)
        else:
            self.VLBridge = None

        # feed forward
        self.feedForwardLinear = nn.Linear(768, 768)
        self.feedForwardLayerNorm = nn.LayerNorm(768)

        # gemma
        # Linear(64, 16) is applied along a transposed axis in gemmaGenerate(), i.e. it
        # maps an axis of length 64 down to 16.
        self.gemmaLinearMaxTokens = nn.Linear(64, 16)
        self.gemmaLinearBefore = nn.Linear(768, gemmaConfig.vocab_size)
        self.gemmaSoftmax = nn.Softmax(dim=2)
        # NOTE(review): presumably children() of the causal-LM wrapper are
        # (backbone, lm_head), so [:-1] is the backbone and [1:] the LM head — confirm.
        self.gemma = nn.Sequential(*list(gemma.children())[:-1])
        self.gemmaLm_head = nn.Sequential(*list(gemma.children())[1:])

        # funny score
        self.FunnyScorelinear1 = nn.Linear(768, 1)
        self.FunnyScorelinear2 = nn.Linear(64, 1)

    def gemmaGenerate(self, x):
        """Turn fused features into token ids and run them through ``self.gemma``.

        Everything here runs under ``torch.no_grad()``, so no gradient flows back
        through this path. Returns element 0 of the Gemma output
        (the original comments call it ``last_hidden_state``).
        """
        with torch.no_grad():
            # Shrink the transposed axis 64 -> 16.
            # NOTE(review): the original comment said "maximum 32 tokens", which does
            # not match Linear(64, 16) — confirm the intended length.
            x = self.gemmaLinearMaxTokens(x.transpose(1, 2)).transpose(1, 2)
            x = self.gemmaLinearBefore(x)
            x = self.gemmaSoftmax(x)
            # Keep only the index of the highest-probability entry along dim 2.
            _, top_k_indices = torch.topk(x, 1, dim=2, largest=True)
            toGemma = textExtractReverse(top_k_indices).to(device)
            # use gemma as part of the model
            output = self.gemma(toGemma)
            # output[0] = last_hidden_state
            # output[1] = past_key_values
        return output[0]

    def forward_encoder(self, samples):
        """Encode an image batch with the vision encoder.

        Accepts either the image tensor itself (as ``forward`` and ``generate`` pass it)
        or a LAVIS-style mapping with an ``"image"`` entry.
        """
        image = samples if torch.is_tensor(samples) else samples["image"]
        image_embeds = self.visual_encoder.forward_features(image)
        return image_embeds

    def forward_decoder(self, text, image_embeds):
        """Run the text decoder over ``text`` conditioned on ``image_embeds``.

        NOTE(review): this reads ``text.input_ids`` / ``text.attention_mask``, i.e. it
        expects tokenizer output, while the training loop in this notebook passes a
        float tensor — confirm which one is intended.

        Returns:
            ``(decoder_output, decoder_targets)`` where the targets are the input ids
            with padding and the prompt prefix replaced by ``-100``.
        """
        # prepare inputs for forwarding decoder
        # text = self.tokenizer(
        #     raw_text,
        #     padding="longest",
        #     truncation=True,
        #     max_length=self.max_txt_len,
        #     return_tensors="pt",
        # ).to(self.device)
        text.input_ids[:, 0] = self.tokenizer.bos_token_id

        # prepare targets for forwarding decoder
        decoder_targets = text.input_ids.masked_fill(
            text.input_ids == self.tokenizer.pad_token_id, -100
        )
        decoder_targets[:, : self.prompt_length] = -100

        # forward decoder
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            self.device
        )
        decoder_output = self.text_decoder(
            input_ids=text.input_ids,
            attention_mask=text.attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            labels=decoder_targets,
            return_dict=True,
        )

        return decoder_output, decoder_targets

    def _fuse_and_score(self, hidden_states):
        """Shared tail of ``forward`` and ``generate``: fusion, Gemma logits, funny score."""
        feature_fusion = self.feedForwardLinear(hidden_states)
        # NOTE(review): this adds the projection to itself rather than to
        # `hidden_states`; if a residual connection was intended it should probably be
        # `hidden_states + feature_fusion` — left unchanged pending confirmation.
        feature_fusion = self.feedForwardLayerNorm(feature_fusion + feature_fusion)
        feature_fusion = feature_fusion.squeeze(-1)
        feature_fusion = feature_fusion.transpose(0, 1)

        # gemma generate
        last_hidden_state = self.gemmaGenerate(feature_fusion)
        output_text = self.gemmaLm_head(last_hidden_state)

        # funny score: collapse the feature axis, then the remaining length-64 axis
        output_funny_score = self.FunnyScorelinear1(feature_fusion).squeeze(-1)
        output_funny_score = self.FunnyScorelinear2(output_funny_score).squeeze(-1)
        return output_text, output_funny_score

    def forward(self, text, image):
        """
        Args:
            text: decoder input, passed to ``forward_decoder`` unchanged.
            image: image batch, passed to ``forward_encoder`` (a tensor or a mapping
                with an ``"image"`` entry).
        Returns:
            ``(output_text, output_funny_score)``: the output of the Gemma LM head and
            the output of the funny-score head.
        """
        image_embeds = self.forward_encoder(image)
        if self.VLBridge is not None:
            image_embeds = self.VLBridge(image_embeds)
        decoder_output, decoder_targets = self.forward_decoder(text, image_embeds)

        return self._fuse_and_score(decoder_output.last_hidden_state)

    def generate(
        self,
        image,
        use_nucleus_sampling=False,
        num_beams=3,
        max_length=100,
        min_length=10,
        top_p=0.9,
        repetition_penalty=1.0,
        num_captions=1,
    ):
        """
        Greedy caption generation, one token per loop iteration.

        Args:
            image: image batch, passed to ``forward_encoder``.
                NOTE(review): ``.item()`` on the chosen token id means this only works
                for a single sequence — confirm batch size 1 is intended.
            use_nucleus_sampling (bool): forwarded to the text decoder.
            num_beams (int): forwarded to the text decoder.
            max_length (int): forwarded to the text decoder; also bounds the number of
                loop iterations and the caption length in words.
            min_length (int): forwarded to the text decoder.
            top_p (float): forwarded to the text decoder.
            repetition_penalty (float): forwarded to the text decoder.
            num_captions (int): how many times each image embedding is repeated.
        Returns:
            ``(generated_caption, output_funny_score)``: the decoded caption string and
            the funny-score head output from the last iteration that ran.
        """
        # Spaces sometimes get lost, so insert them manually: <pad> = 0
        def insert_zeros(token_ids):
            padded = [0] * (2 * len(token_ids) - 1)
            padded[::2] = token_ids
            return padded

        # The config may hold one id or a list of ids; normalise so `in` always works.
        eos_token_ids = gemmaConfig.eos_token_id
        if eos_token_ids is None:
            eos_token_ids = ()
        elif isinstance(eos_token_ids, int):
            eos_token_ids = (eos_token_ids,)

        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
        generated_tokens = [tokenizer.bos_token_id] #<bos> = 2
        # Defined up front so the fall-through return below is always valid.
        generated_caption = ""
        output_funny_score = None
        text = torch.zeros_like(image).to(device)
        lastTurn = False
        with torch.no_grad():
            # The image does not change between steps, so encode it once.
            encoder_out = self.forward_encoder(image)
            if self.VLBridge is not None:
                encoder_out = self.VLBridge(encoder_out)
            image_embeds = torch.repeat_interleave(encoder_out, num_captions, 0)

            for _ in range(max_length + 1):
                # prompt = [self.prompt] * image_embeds.size(0)
                # prompt = self.tokenizer(prompt, return_tensors="pt").to(self.device)
                # prompt.input_ids[:, 0] = self.tokenizer.bos_token_id
                # prompt.input_ids = prompt.input_ids[:, :-1]

                # get decoded text
                decoder_out = self.text_decoder.generate_from_encoder(
                    # tokenized_prompt=prompt,
                    visual_embeds=image_embeds,
                    sep_token_id=self.tokenizer.sep_token_id,
                    pad_token_id=self.tokenizer.pad_token_id,
                    use_nucleus_sampling=use_nucleus_sampling,
                    num_beams=num_beams,
                    max_length=max_length,
                    min_length=min_length,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                )

                output_text, output_funny_score = self._fuse_and_score(decoder_out)

                if lastTurn: # show final funny score
                    return generated_caption, output_funny_score

                next_token_logits = output_text[:, -1, :]
                next_token_probs = torch.softmax(next_token_logits, dim=-1)
                next_token_id = torch.argmax(next_token_probs, dim=-1).item()
                generated_tokens.append(next_token_id)

                # Decode with pads interleaved, then drop pads and any "<...>" token.
                generated_caption = insert_zeros(generated_tokens)
                generated_caption = tokenizer.decode(generated_caption, skip_special_tokens=False)
                generated_caption = generated_caption.replace("<pad>", " ").replace("  ", " ").split()
                generated_caption = [word for word in generated_caption if word[0] != "<"]
                generated_caption = " ".join(generated_caption)

                # NOTE(review): `text` is recomputed here but never fed back into the
                # decoder, so each step sees the same inputs — confirm the intended wiring.
                text = textExtraction([generated_caption]).to(device)
                # text = text.transpose(0, 1)

                if next_token_id in eos_token_ids or len(generated_caption.split()) > max_length:
                    #<eos> = 1; <end_of_turn> = 107
                    lastTurn = True
                # outputs = self.tokenizer.batch_decode(decoder_out, skip_special_tokens=True)
                # captions = [output[len(self.prompt) :] for output in outputs]

                # return captions

        # Loop exhausted without reaching an end condition: return what was generated
        # instead of falling through to an implicit None.
        return generated_caption, output_funny_score

    @classmethod
    def from_config(cls, cfg):
        """Build the model from a config mapping and load its checkpoint.

        Config keys mirror the constructor's keyword names; keys that are absent (or
        ``None``) fall back to the constructor defaults. The encoder and decoder are
        built inside ``__init__``, so they are not constructed here.
        """
        ctor_keys = (
            "vit_type",
            "med_config_path",
            "adapter_type",
            "bert_adapter",
            "visual_projection",
            "tune_language",
            "prompt",
            "max_txt_len",
        )
        ctor_kwargs = {}
        for key in ctor_keys:
            value = cfg.get(key, None)
            if value is not None:
                ctor_kwargs[key] = value

        model = cls(**ctor_kwargs)
        model.load_checkpoint_from_config(cfg)

        return model

# 3. Discriminator

In [None]:
class Discriminator(nn.Module):
    """Scores real / generated / mismatched captions, with and without the image.

    The module holds two groups of layers: ``g_*`` layers that produce the logits
    fed to the generator loss, and ``d_*`` layers that produce the logits fed to
    the discriminator loss.

    Sequence-length constraints implied by the second-stage linear layers:
    ``real_text``, ``fake_text`` and ``image`` must each have length 64 along
    dim 1 (64 for the unconditional heads, 64 + 64 = 128 for text+image pairs,
    and 128 + 128 = 256 for the real/fake contrastive pairs).
    """

    def __init__(self, vocab_size=None):
        """
        Args:
            vocab_size: size of the last axis of ``fake_text``. Defaults to the
                notebook-level ``gemmaConfig.vocab_size`` so both projection layers
                always agree with the loaded Gemma model.
        """
        super().__init__()
        if vocab_size is None:
            vocab_size = gemmaConfig.vocab_size
        # Generator
        self.g_linearFake = nn.Linear(vocab_size, 768)
        self.g_con_mlp1 = nn.Linear(768, 2)
        self.g_con_mlp2 = nn.Linear(128, 1)
        self.g_unc_mlp1 = nn.Linear(768, 1)
        self.g_unc_mlp2 = nn.Linear(64, 1)
        # Discriminator
        self.d_linearFake = nn.Linear(vocab_size, 768)
        self.d_con_mlp1_r2f = nn.Linear(768, 2)
        self.d_con_mlp2_r2f = nn.Linear(256, 1)
        self.d_con_mlp1_f2r = nn.Linear(768, 2)
        self.d_con_mlp2_f2r = nn.Linear(256, 1)
        self.d_con_mlp1_g = nn.Linear(768, 2)
        self.d_con_mlp2_g = nn.Linear(128, 1)
        self.d_con_mlp1_m = nn.Linear(768, 2)
        self.d_con_mlp2_m = nn.Linear(128, 1)
        self.d_unc_mlp1_r = nn.Linear(768, 1)
        self.d_unc_mlp2_r = nn.Linear(64, 1)
        self.d_unc_mlp1_g = nn.Linear(768, 1)
        self.d_unc_mlp2_g = nn.Linear(64, 1)
        self.d_unc_mlp1_m = nn.Linear(768, 1)
        self.d_unc_mlp2_m = nn.Linear(64, 1)

    def forward(self, real_text, fake_text, image):
        """
        Args:
            real_text: ``[batch_size, 64, 768]``
            fake_text: ``[batch_size, 64, vocab_size]``
            image: ``[batch_size, 64, 768]``
        Returns:
            ``(g_C_g, g_UC_g, d_con_output, d_unc_output)`` with shapes
            ``[batch, 2]``, ``[batch]``, ``[4, batch, 2]`` and ``[3, batch]``.
            ``d_con_output`` stacks (real→fake, fake→real, generated, mismatched);
            ``d_unc_output`` stacks (real, generated, mismatched).
        """
        # Project the vocabulary axis of the generated text down to 768 features,
        # once per branch.
        g_fake_text = self.g_linearFake(fake_text)
        d_fake_text = self.d_linearFake(fake_text)
        # Shift captions by one position along the batch axis so each caption is
        # paired with another sample's image (for batch size 1 this is a no-op).
        mismatched_text = torch.roll(real_text, 1, 0)

        # conditional (contrastive): concatenate text and image along the sequence axis
        C_r = torch.cat((real_text, image), dim=1)
        g_C_g = torch.cat((g_fake_text, image), dim=1)
        d_C_g = torch.cat((d_fake_text, image), dim=1)
        C_m = torch.cat((mismatched_text, image), dim=1)
        # contrastive discriminator: the same two pairs in both orders
        d_C_r2f = torch.cat((C_r, d_C_g), dim=1)
        d_C_f2r = torch.cat((d_C_g, C_r), dim=1)
        ########################## Generator ##########################
        g_C_g = self.g_con_mlp1(g_C_g)
        g_C_g = self.g_con_mlp2(g_C_g.transpose(1, 2)).squeeze(-1)
        ###############################################################

        ######################## Discriminator ########################
        d_C_r2f = self.d_con_mlp1_r2f(d_C_r2f)
        d_C_f2r = self.d_con_mlp1_f2r(d_C_f2r)
        d_C_g = self.d_con_mlp1_g(d_C_g)
        d_C_m = self.d_con_mlp1_m(C_m)
        d_C_r2f = self.d_con_mlp2_r2f(d_C_r2f.transpose(1, 2)).squeeze(-1).unsqueeze(0)
        # Each branch uses its own second-stage layer (f2r previously reused the r2f one).
        d_C_f2r = self.d_con_mlp2_f2r(d_C_f2r.transpose(1, 2)).squeeze(-1).unsqueeze(0)
        d_C_g = self.d_con_mlp2_g(d_C_g.transpose(1, 2)).squeeze(-1).unsqueeze(0)
        d_C_m = self.d_con_mlp2_m(d_C_m.transpose(1, 2)).squeeze(-1).unsqueeze(0)
        d_con_output = torch.cat((d_C_r2f, d_C_f2r, d_C_g, d_C_m), dim=0)
        ###############################################################

        #### unconditional ####
        ########################## Generator ##########################
        g_UC_g = self.g_unc_mlp1(g_fake_text).squeeze(-1)
        g_UC_g = self.g_unc_mlp2(g_UC_g).squeeze(-1)
        ###############################################################

        ######################## Discriminator ########################
        d_UC_r = self.d_unc_mlp1_r(real_text).squeeze(-1)
        d_UC_g = self.d_unc_mlp1_g(d_fake_text).squeeze(-1)
        d_UC_m = self.d_unc_mlp1_m(mismatched_text).squeeze(-1)
        d_UC_r = self.d_unc_mlp2_r(d_UC_r).squeeze(-1).unsqueeze(0)
        d_UC_g = self.d_unc_mlp2_g(d_UC_g).squeeze(-1).unsqueeze(0)
        d_UC_m = self.d_unc_mlp2_m(d_UC_m).squeeze(-1).unsqueeze(0)
        d_unc_output = torch.cat((d_UC_r, d_UC_g, d_UC_m), dim=0)
        ###############################################################
        return g_C_g, g_UC_g, d_con_output, d_unc_output

In [None]:
# Free GPU memory. Collect garbage first so tensors that are only kept alive by
# reference cycles are actually released; only then can empty_cache() hand their
# cached blocks back to the driver.
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generator (captioner + funny-score head) and discriminator, each with its own Adam optimizer.
NetG = BlipCaption().to(device)
NetD = Discriminator().to(device)
optimizer_G = optim.Adam(NetG.parameters(), lr=0.001)
optimizer_D = optim.Adam(NetD.parameters(), lr=0.001)

# Per-epoch loss histories (FC = funny-score loss, G = generator, D = discriminator).
train_losses_FC = []
train_losses_G = []
train_losses_D = []
test_losses_FC = []
test_losses_G = []
test_losses_D = []
save = []              # one marker per epoch: "V" if a checkpoint was written, " " otherwise
present_epoch = 1      # number of the first epoch the training loop will run

# Best epoch-mean losses seen so far; checkpoints are written when these improve.
best_train_loss_FC = 9999
best_train_loss_G = 9999
best_train_loss_D = 9999
best_test_loss_FC = 9999
best_test_loss_G = 9999
best_test_loss_D = 9999
loss_data = pd.DataFrame()

# Set to True to resume from the checkpoint files below.
checkpoint = False
if checkpoint:
    # map_location lets a checkpoint written on a GPU load on whatever device is in use now.
    checkpoint_G = torch.load('./AdpaterModel/test_save/test_save_2NetG.pth', map_location=device)
    checkpoint_D = torch.load('./AdpaterModel/test_save/test_save_2NetD.pth', map_location=device)
    NetG.load_state_dict(checkpoint_G['model_state_dict'])
    NetD.load_state_dict(checkpoint_D['model_state_dict'])
    optimizer_G.load_state_dict(checkpoint_G['optimizer_state_dict'])
    optimizer_D.load_state_dict(checkpoint_D['optimizer_state_dict'])
    # Stored losses may be 0-dim tensors; keep the histories as plain floats.
    train_losses_FC.append(float(checkpoint_G['FC_loss']))
    train_losses_G.append(float(checkpoint_G['G_loss']))
    train_losses_D.append(float(checkpoint_G['D_loss']))
    # The checkpoint carries no test losses. Pad the other histories with one entry so
    # every list stays the same length; the per-epoch loss table needs equal-length columns.
    test_losses_FC.append(float('nan'))
    test_losses_G.append(float('nan'))
    test_losses_D.append(float('nan'))
    save.append("V")
    present_epoch = checkpoint_G['epoch'] + 1



funnyScoreLoss = nn.MSELoss()

def generatorLoss(condition_logits, uncondition_logits):
    """Sum of the conditional and unconditional generator losses.

    Args:
        condition_logits: class logits of shape ``[batch, num_classes]``; scored with
            cross-entropy against class index 0.
        uncondition_logits: logits of shape ``[batch]``; scored with
            BCE-with-logits against an all-zero target.
    Returns:
        A scalar tensor: cross-entropy + BCE.

    NOTE(review): both targets are the all-zero ("fake") label — confirm this is the
    intended label convention for the generator update.
    """
    # Build the targets on the logits' own device so the function does not depend on
    # a notebook-level `device` variable.
    result_fake = torch.zeros(uncondition_logits.shape[0], device=uncondition_logits.device)
    con_loss = CrossEntropyLoss()(condition_logits, result_fake.to(torch.long))
    unc_loss = BCEWithLogitsLoss()(uncondition_logits, result_fake)
    return con_loss + unc_loss

def discriminatorLoss(condition_logits, uncondition_logits):
    """Combined conditional + unconditional discriminator loss.

    Args:
        condition_logits: stacked class logits indexed as
            ``[0]`` real→fake, ``[1]`` fake→real, ``[2]`` generated, ``[3]`` mismatched;
            each entry has shape ``[batch, num_classes]`` and is scored with
            cross-entropy against class index 0.
        uncondition_logits: stacked logits indexed as ``[0]`` real, ``[1]`` generated,
            ``[2]`` mismatched; each entry has shape ``[batch]``. The real entry is
            scored against ones, the other two against zeros.
    Returns:
        A scalar tensor:
        ``mean(r2f, f2r) + mean(generated, mismatched) + real + mean(unc generated, unc mismatched)``.

    NOTE(review): all four conditional terms use target class 0 — confirm that the
    real→fake and fake→real orderings are meant to share the same label.
    """
    batch_size = uncondition_logits[0].shape[0]
    # Targets are created on the logits' own device so the function does not depend on
    # a notebook-level `device` variable.
    target_device = uncondition_logits.device
    result_true = torch.ones(batch_size, device=target_device)
    result_fake = torch.zeros(batch_size, device=target_device)
    class_zero = result_fake.to(torch.long)

    cross_entropy = CrossEntropyLoss()
    bce = BCEWithLogitsLoss()

    con_r2f = cross_entropy(condition_logits[0], class_zero)
    con_f2r = cross_entropy(condition_logits[1], class_zero)
    con_f = cross_entropy(condition_logits[2], class_zero)
    con_m = cross_entropy(condition_logits[3], class_zero)
    unc_r = bce(uncondition_logits[0], result_true)
    unc_f = bce(uncondition_logits[1], result_fake)
    unc_m = bce(uncondition_logits[2], result_fake)
    loss = ((con_r2f + con_f2r)/2) + ((con_f + con_m)/2) + unc_r + ((unc_f + unc_m)/2)
    return loss

In [None]:
save_name = '20241011'
save_dir = './AdpaterModel/' + save_name
# exist_ok replaces the separate existence check.
os.makedirs(save_dir, exist_ok=True)


def _save_checkpoint(net, optimizer, path, epoch_id, fc_loss, g_loss, d_loss):
    """Write one resumable checkpoint: weights, optimizer state, epoch number and losses."""
    torch.save({
        'epoch': epoch_id,
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'FC_loss': fc_loss,
        'G_loss': g_loss,
        'D_loss': d_loss,
    }, path)


epochs = 2
# NOTE(review): anomaly detection is a debugging aid that slows every backward pass;
# consider turning it off once the autograd issue being investigated is resolved.
torch.autograd.set_detect_anomaly(True)
for epoch in range(epochs):
    epoch_id = epoch + present_epoch
    print("---------------------------------------- epoch "+ str(epoch_id) +" ---------------------------------------")
    train_loss_FC = 0
    train_loss_G = 0
    train_loss_D = 0
    test_loss_FC = 0
    test_loss_G = 0
    test_loss_D = 0

    ###################################### Train ######################################
    NetG.train()
    NetD.train()
    with tqdm(train_loader, unit="batch", leave=True) as tepoch:
        tepoch.set_postfix({'Now': " New batch preprocessing"})
        for text, image, funny_score in tepoch:
            # Move each tensor to the device / dtype once and reuse it below.
            text = text.squeeze(1).to(device).to(torch.float32)
            image = image.squeeze(1).to(device).to(torch.float32)
            funny_score = funny_score.to(device).to(torch.float32)
            # print(text.shape, image.shape, funny_score.shape)
            # torch.Size([32, 64, 768]) torch.Size([32, 64, 768]) torch.Size([32, 1])
            ######################################################
            # (1) Generate fake caption
            ######################################################
            tepoch.set_postfix({'Now': " Generating fake caption -> Generator"})
            logits, output_funny_score = NetG(text, image)
            tepoch.set_postfix({'Now': " Generating fake caption -> Discriminator"})
            # NOTE(review): the logits are detached before entering NetD, so the generator
            # loss below produces no gradient for NetG's parameters — confirm this is intended.
            g_con_logits, g_unc_logits, d_con_logits, d_unc_logits = NetD(text, logits.detach().to(torch.float32), image)
            ######################################################
            # (2) Update Discriminator network
            #####################################################
            tepoch.set_postfix({'Now': " Updating Discriminator network"})
            optimizer_D.zero_grad()
            loss_D = discriminatorLoss(d_con_logits, d_unc_logits)
            # retain_graph: the generator loss below backpropagates through the same NetD forward pass.
            loss_D.backward(retain_graph=True)
            optimizer_D.step()
            train_loss_D += loss_D.item()
            ######################################################
            # (3) Update Generator network
            ######################################################
            tepoch.set_postfix({'Now': " Updating Generator network"})
            optimizer_G.zero_grad()
            loss_FC = funnyScoreLoss(output_funny_score, funny_score)
            loss_FC.backward(retain_graph=True)
            train_loss_FC += loss_FC.item()
            loss_G = generatorLoss(g_con_logits, g_unc_logits)
            loss_G.backward()
            optimizer_G.step()
            train_loss_G += loss_G.item()
            ######################################################
            tepoch.set_postfix({'Now': " New batch preprocessing"})
            ######################################################
    train_loss_FC /= len(train_loader)
    train_loss_G /= len(train_loader)
    train_loss_D /= len(train_loader)
    train_losses_FC.append(train_loss_FC)
    train_losses_G.append(train_loss_G)
    train_losses_D.append(train_loss_D)
    ###################################### Train ######################################


    ######################################  Test ######################################
    # Evaluation: switch to eval mode and disable autograd so no graph is built or kept.
    NetG.eval()
    NetD.eval()
    with torch.no_grad(), tqdm(test_loader, unit="batch", leave=True) as tepoch:
        tepoch.set_postfix({'Now': " New batch preprocessing"})
        for text, image, funny_score in tepoch:
            text = text.squeeze(1).to(device).to(torch.float32)
            image = image.squeeze(1).to(device).to(torch.float32)
            funny_score = funny_score.to(device).to(torch.float32)
            # Generator
            tepoch.set_postfix({'Now': " Generating fake caption -> Generator"})
            logits, output_funny_score = NetG(text, image)
            # Discriminator
            tepoch.set_postfix({'Now': " Generating fake caption -> Discriminator"})
            g_con_logits, g_unc_logits, d_con_logits, d_unc_logits = NetD(text, logits.detach().to(torch.float32), image)
            # loss
            tepoch.set_postfix({'Now': " Computing loss"})
            loss_FC = funnyScoreLoss(output_funny_score, funny_score)
            loss_G = generatorLoss(g_con_logits, g_unc_logits)
            loss_D = discriminatorLoss(d_con_logits, d_unc_logits)
            test_loss_FC += loss_FC.item()
            test_loss_G += loss_G.item()
            test_loss_D += loss_D.item()
    test_loss_FC /= len(test_loader)
    test_loss_G /= len(test_loader)
    test_loss_D /= len(test_loader)
    test_losses_FC.append(test_loss_FC)
    test_losses_G.append(test_loss_G)
    test_losses_D.append(test_loss_D)
    ######################################  Test ######################################

    ######################################  Save ######################################
    # Save when, for any one of the three losses, both the train and the test epoch mean
    # beat the best values seen so far.
    improved_FC = train_loss_FC < best_train_loss_FC and test_loss_FC < best_test_loss_FC
    improved_G = train_loss_G < best_train_loss_G and test_loss_G < best_test_loss_G
    improved_D = train_loss_D < best_train_loss_D and test_loss_D < best_test_loss_D
    if improved_FC:
        best_train_loss_FC = train_loss_FC
        best_test_loss_FC = test_loss_FC
    if improved_G:
        best_train_loss_G = train_loss_G
        best_test_loss_G = test_loss_G
    if improved_D:
        best_train_loss_D = train_loss_D
        best_test_loss_D = test_loss_D

    hasSaved = improved_FC or improved_G or improved_D
    if hasSaved:
        # All three criteria write to the same per-epoch file names, so one write suffices.
        # The epoch-mean train losses are stored as plain floats, because the resume code
        # appends them to the train loss histories.
        _save_checkpoint(
            NetG, optimizer_G,
            './AdpaterModel/' + save_name + "/" + save_name + '_NetG_'+ str(epoch_id) +'.pth',
            epoch_id, train_loss_FC, train_loss_G, train_loss_D,
        )
        _save_checkpoint(
            NetD, optimizer_D,
            './AdpaterModel/' + save_name + "/" + save_name + '_NetD_'+ str(epoch_id) +'.pth',
            epoch_id, train_loss_FC, train_loss_G, train_loss_D,
        )

    save.append("V" if hasSaved else " ")

    # Rebuild the table from the histories every epoch; assigning columns onto the existing
    # frame fails once the histories outgrow its index (e.g. when this cell is re-run).
    loss_data = pd.DataFrame({
        'train_FC': train_losses_FC,
        'train_G': train_losses_G,
        'train_D': train_losses_D,
        'test_FC': test_losses_FC,
        'test_G': test_losses_G,
        'test_D': test_losses_D,
        'save': save,
    })
    loss_data.to_csv('./AdpaterModel/' + save_name + "/" + save_name + '_loss.csv', index=False)
    ######################################  Save ######################################


In [None]:
# @register_model

In [None]:
# List the timm model names matching "*swin*".
# The result is not stored under the name `list`: that would shadow the builtin for the
# rest of the kernel session and break later `list(...)` calls in this notebook.
# NOTE(review): `import timm` is commented out at the top of the notebook, so this cell
# raises NameError on a fresh kernel — confirm whether timm is still needed.
swin_models = timm.list_models("*swin*")
print(swin_models)

In [None]:
# save my model on timm
# NOTE(review): `timm` is not imported in this notebook (the import is commented out).
# NOTE(review): presumably `model_entrypoint` needs a model name argument and looks up an
# existing entrypoint rather than saving anything — confirm against the timm docs.
timm.model_entrypoint()