In [None]:

import os

import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchvision.models as models
from PIL import Image
from torch import optim, nn
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy

from TALENT.model.utils import get_deep_args,get_method

In [None]:
class CrossAttention(nn.Module):
    """Multi-head attention between one query vector and one key/value vector.

    Both inputs are linearly projected to ``dim_out`` and handed to
    ``nn.MultiheadAttention`` (``batch_first=True``) after a dimension is
    inserted at index 1.

    Args:
        dim_q: Feature width of the query input.
        dim_kv: Feature width of the key/value input.
        dim_out: Width of the projections and of the attention output.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied inside the attention layer.
    """

    def __init__(self, dim_q, dim_kv, dim_out, num_heads=4, dropout=0.1):
        super().__init__()
        # The attention layer is created before the projections so that the
        # parameter order (and random initialisation order) stays stable.
        self.attn = nn.MultiheadAttention(
            embed_dim=dim_out,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.proj_q = nn.Linear(in_features=dim_q, out_features=dim_out)
        self.proj_kv = nn.Linear(in_features=dim_kv, out_features=dim_out)

    def forward(self, q, kv):
        """Attend from ``q`` to ``kv`` and return the attended features.

        Assumes ``q`` and ``kv`` are 2-D ``(batch, features)`` tensors, so the
        inserted dimension acts as a sequence of length one -- TODO confirm
        against callers.
        """
        query_seq = self.proj_q(q).unsqueeze(1)
        projected_kv = self.proj_kv(kv)
        # Key and value are two separate views of the same projection.
        key_seq = projected_kv.unsqueeze(1)
        value_seq = projected_kv.unsqueeze(1)
        attended, _attn_weights = self.attn(query_seq, key_seq, value_seq)
        return attended.squeeze(1)

class Identity(nn.Module):
    """Parameter-free module whose forward pass returns its input untouched.

    It is swapped in for a backbone's classification layer so that the
    backbone emits its feature vector instead of class logits.
    """

    def forward(self, x):
        """Return ``x`` itself (same object, no copy)."""
        return x

In [None]:
class ImageClassifier(nn.Module):
    """Two-branch classifier fusing CNN image features with tabular embeddings.

    An ImageNet-pretrained torchvision backbone, with its own classification
    layer replaced by ``Identity``, produces an image feature vector.  The
    tabular inputs are embedded by the ``"ftt"`` model built through TALENT.
    ``forward`` returns logits from the image head, from the tabular model and
    from a fused head, together with the intermediate features that the
    auxiliary losses (``InfoMax_loss`` / ``InfoMin_loss``) consume.

    Args:
        model_name: One of ``"resnet"``, ``"densenet"``, ``"inception"`` or
            ``"mobilenet"``.
        in_dims: Kept for interface compatibility only; the value is always
            replaced by the feature width reported by the chosen backbone.
        out_dims: Number of logits produced by the image head and fused head.
        n_num_features: Number of continuous tabular features.
        cat_cardinalities: One entry per categorical tabular feature.  ``None``
            (the default) means there are no categorical features.
        d_token: Per-feature embedding width; assumed to match the token width
            of the TALENT model's embedding output -- TODO confirm.

    Raises:
        ValueError: If ``model_name`` is not one of the supported backbones.
    """

    def __init__(self, model_name: str = 'resnet', in_dims: int = 2048,
                 out_dims: int = 129,
                 n_num_features: int = 0,
                 cat_cardinalities: list = None,
                 d_token: int = 8):
        super().__init__()

        # Avoid sharing one mutable default list between instances.
        if cat_cardinalities is None:
            cat_cardinalities = []

        self.model_name = model_name
        backbone, in_dims = self._build_backbone(model_name)
        img_fc = nn.Sequential(
            nn.Linear(in_dims, 1024),
            nn.Linear(1024, out_dims)
        )

        self.in_dims = in_dims
        self.table_dim = n_num_features + len(cat_cardinalities)

        self.con_fc_num = n_num_features
        self.cat_fc_num = len(cat_cardinalities)
        self.tab_model = self.get_ftt_from_talent()

        self.backbone = backbone
        self.img_fc = img_fc
        self.device = None

        tab_embd_dim = d_token * n_num_features + len(cat_cardinalities) * d_token
        cross_attn_out_dim = tab_embd_dim

        # NOTE(review): the next three modules are never used by ``forward``.
        # They are kept because they contribute parameters to the optimizer
        # and keys to the state dict of existing checkpoints.
        self.img_to_tab_cross_attn = CrossAttention(dim_q=in_dims, dim_kv=tab_embd_dim, dim_out=cross_attn_out_dim)
        self.tab_to_img_cross_attn = CrossAttention(dim_q=tab_embd_dim, dim_kv=in_dims, dim_out=cross_attn_out_dim)
        self.img_to_tab_dim = nn.Linear(in_dims, tab_embd_dim)

        self.cross_dim = self.table_dim * d_token
        self.image_cross_proj = nn.Linear(in_dims, self.cross_dim)

        # Fused head input: image features + projected image features +
        # flattened table embedding + cross table features.
        self.concat_fc = nn.Sequential(
            nn.Linear(tab_embd_dim * 3 + in_dims, out_dims),
        )

    @staticmethod
    def _build_backbone(model_name):
        """Create the pretrained backbone and strip its classification layer.

        Returns:
            ``(backbone, feat_dim)`` where ``feat_dim`` is the ``in_features``
            of the layer that was replaced by ``Identity``.

        Raises:
            ValueError: If ``model_name`` is not supported.
        """
        if model_name == "resnet":
            backbone = models.resnet50(weights=models.resnet.ResNet50_Weights.IMAGENET1K_V1)
            feat_dim = backbone.fc.in_features
            backbone.fc = Identity()
        elif model_name == "densenet":
            backbone = models.densenet121(weights=models.densenet.DenseNet121_Weights.IMAGENET1K_V1)
            feat_dim = backbone.classifier.in_features
            backbone.classifier = Identity()
        elif model_name == "inception":
            backbone = models.googlenet(weights=models.GoogLeNet_Weights.IMAGENET1K_V1)
            feat_dim = backbone.fc.in_features
            backbone.fc = Identity()
        elif model_name == "mobilenet":
            backbone = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
            feat_dim = backbone.classifier[-1].in_features
            backbone.classifier[-1] = Identity()
        else:
            raise ValueError(
                f"Unsupported model_name {model_name!r}; expected one of "
                "'resnet', 'densenet', 'inception', 'mobilenet'."
            )
        return backbone, feat_dim

    def get_ftt_from_talent(self):
        """Build the TALENT ``"ftt"`` method and return its ``model`` attribute.

        The arguments come from ``get_deep_args()`` with ``model_type``
        overridden to ``"ftt"``.  The meaning of the second positional
        argument (``True``) passed to the method class is defined by TALENT
        -- TODO confirm.
        """
        args, _, _ = get_deep_args()
        args.model_type = "ftt"
        method = get_method(args.model_type)(args, True)
        return method.model

    def InfoMax_loss(self, x1, x2):
        """Contrastive-style loss between two batches of feature vectors.

        Rows of ``x1`` and ``x2`` are L2-normalised and compared with a dot
        product.  ``Ej`` is the mean of the similarity matrix with off-diagonal
        entries zeroed, ``Em`` is the mean of ``exp`` of the matrix with the
        diagonal zeroed, and the result is ``-(Ej - log(Em))``.

        Args:
            x1: Tensor of shape ``(batch, dim)``.
            x2: Tensor of shape ``(batch, dim)``.

        Returns:
            Scalar loss tensor.
        """
        x1 = x1 / (torch.norm(x1, p=2, dim=1, keepdim=True) + 1e-10)
        x2 = x2 / (torch.norm(x2, p=2, dim=1, keepdim=True) + 1e-10)
        bs = x1.size(0)
        s = torch.matmul(x1, x2.permute(1, 0))
        # Build the mask on the inputs' device instead of forcing CUDA, so the
        # loss also works on CPU and on non-default GPUs.
        mask_joint = torch.eye(bs, device=s.device)
        mask_marginal = 1 - mask_joint

        Ej = (s * mask_joint).mean()
        Em = torch.exp(s * mask_marginal).mean()
        infomax_loss = - (Ej - torch.log(Em))
        return infomax_loss

    @staticmethod
    def _gaussian_stat_term(feats):
        """Return ``0.5 * (1 + log_var - mu**2 - var)`` for one feature batch.

        ``mu`` and ``var`` are the mean and (unbiased) variance of the
        batch-mean feature vector, i.e. of ``feats.mean(dim=0)``.

        NOTE(review): for 2-D ``(batch, dim)`` inputs the variance is therefore
        taken across the feature axis of the batch mean, not across the batch.
        Confirm this is the intended statistic.
        """
        batch_mean = feats.mean(dim=0)
        log_var = torch.log(batch_mean.var(dim=0) + 1e-10)
        mu = batch_mean.mean(dim=0)
        var = torch.exp(log_var)
        return 0.5 * torch.sum(1 + log_var - mu ** 2 - var)

    def InfoMin_loss(self, cross_img_feats, cross_tab_feats):
        """KL-style penalty intended to reduce shared information between modalities.

        For each input, a mean and a log-variance are computed (see
        ``_gaussian_stat_term``) and plugged into the closed-form expression
        ``0.5 * (1 + log_var - mu**2 - var)``, which is the negative KL
        divergence between ``N(mu, var)`` and a standard normal.  The returned
        value is the negated average of the two terms, so minimising it
        minimises that KL divergence.

        Args:
            cross_img_feats: Features from the image modality,
                ``(batch_size, feature_dim)``.
            cross_tab_feats: Features from the tabular modality,
                ``(batch_size, feature_dim)``.

        Returns:
            Scalar loss tensor.
        """
        kl_img = self._gaussian_stat_term(cross_img_feats)
        kl_tab = self._gaussian_stat_term(cross_tab_feats)
        return -(kl_img + kl_tab) / 2

    def forward(self, img, tab_con, tab_cat):
        """Run both branches and the fused head.

        Args:
            img: Image batch fed to the backbone.
            tab_con: Continuous tabular features; ignored (replaced by ``None``)
                when the model was built with ``n_num_features == 0``.
            tab_cat: Categorical tabular features; ignored (replaced by
                ``None``) when the model was built without categorical features.

        Returns:
            Tuple ``(img_out, table_embed_out, concat_out, extracted_feats,
            table_features_embed, cross_img_feats, cross_tab_feats)`` where
            ``table_features_embed`` is flattened to ``(batch, -1)``.
        """
        self.device = img.device

        if self.con_fc_num == 0:
            tab_con = None
        if self.cat_fc_num == 0:
            tab_cat = None

        extracted_feats = self.backbone(img)
        img_out = self.img_fc(extracted_feats)

        table_features_embed, table_embed_out = self.tab_model(tab_con, tab_cat)

        # Flatten the table embedding to (batch, F) and treat it as a
        # single-token sequence (batch, 1, F).
        table_features_embed = table_features_embed.view(table_features_embed.size(0), -1)
        query_table = table_features_embed.unsqueeze(1)
        # Project image features to the same width: (batch, 1, cross_dim).
        query_image = self.image_cross_proj(extracted_feats.unsqueeze(1))

        scale = self.cross_dim ** 0.5

        # NOTE(review): both sequences have length one, so each score matrix
        # is (batch, 1, 1) and the softmax over its last dimension is always
        # 1.  ``cross_tab_feats`` therefore equals the flattened table
        # embedding and ``cross_img_feats`` equals the projected image
        # features.
        scores_img2tab = torch.matmul(query_image, query_table.transpose(-2, -1)) / scale
        attn_weights_img2tab = F.softmax(scores_img2tab, dim=-1)
        cross_tab_feats = torch.matmul(attn_weights_img2tab, query_table).squeeze(dim=1)

        scores_tab2img = torch.matmul(query_table, query_image.transpose(-2, -1)) / scale
        attn_weights_tab2img = F.softmax(scores_tab2img, dim=-1)
        cross_img_feats = torch.matmul(attn_weights_tab2img, query_image).squeeze(dim=1)

        concatenated_features = torch.cat(
            [extracted_feats, cross_img_feats, table_features_embed,
             cross_tab_feats],
            dim=1
        )
        concat_out = self.concat_fc(concatenated_features)

        return img_out, table_embed_out, concat_out, extracted_feats, table_features_embed, cross_img_feats, cross_tab_feats

In [22]:
class ImageModelDVMWithTab(pl.LightningModule):
    """Lightning wrapper that trains ``ImageClassifier`` on DVM image + table data.

    The objective combines cross-entropy on the image head and the fused head
    with two auxiliary terms computed by ``ImageClassifier.InfoMin_loss`` and
    ``ImageClassifier.InfoMax_loss``.  The tabular head's cross-entropy is
    logged but is not part of the optimised loss.

    Args:
        model_name: Backbone name forwarded to ``ImageClassifier``.
        n_num_features: Number of continuous tabular features.
        cat_cardinalities: Cardinalities of the categorical tabular features.
        reverse: Stored on the instance; not read anywhere in this class.
    """

    # Weights of the individual terms of the training/validation objective.
    IMG_LOSS_WEIGHT = 1
    CONCAT_LOSS_WEIGHT = 4
    PROJ_LOSS_WEIGHT = 0.2
    CROSS_LOSS_WEIGHT = 40

    def __init__(self, model_name, n_num_features, cat_cardinalities, reverse=False):
        super().__init__()
        self.net_img_clf = ImageClassifier(model_name=model_name, n_num_features=n_num_features,
                                           cat_cardinalities=cat_cardinalities)
        # One metric object is reused for all three heads; only the per-batch
        # value returned by calling it is logged.
        self.test_acc = Accuracy(task="multiclass", num_classes=129)
        self.reverse = reverse
        # NOTE(review): this eagerly reads the whole validation CSV every time
        # the module is constructed.  Confirm whether `valid_loader` is needed
        # when validation loaders are also passed to `trainer.fit`.
        self.valid_loader = self.val_dataloader()
        self.model_name = model_name

    def val_dataloader(self):
        """Build a fresh, unshuffled loader over the DVM validation split."""
        valid_dataset = DVMConCatImageDataset("dataset/DVM/dataset_valid.csv", "raw_dataset/DVM")
        valid_loader = DataLoader(valid_dataset, batch_size=32, num_workers=8, shuffle=False)

        return valid_loader

    def _shared_step(self, batch):
        """Run the network on ``batch`` and compute every loss term.

        Returns:
            Dict with the three head logits (``img_out``, ``table_embed_out``,
            ``concat_out``), the ``label`` tensor, the individual losses
            (``img_loss``, ``table_embed_loss``, ``concat_loss``,
            ``cross_loss``, ``proj_loss``) and the combined ``loss``.
        """
        (table_features_con, table_features_cat, image_features), label = batch
        (img_out, table_embed_out, concat_out, extracted_feats,
         table_features_embed, cross_img_feats, cross_tab_feats) = self.net_img_clf(
            image_features, table_features_con, table_features_cat)

        img_loss = F.cross_entropy(img_out, label)
        table_embed_loss = F.cross_entropy(table_embed_out, label)
        concat_loss = F.cross_entropy(concat_out, label)
        loss = self.IMG_LOSS_WEIGHT * img_loss + self.CONCAT_LOSS_WEIGHT * concat_loss

        cross_loss = self.net_img_clf.InfoMin_loss(extracted_feats, table_features_embed)
        proj_loss = self.net_img_clf.InfoMax_loss(cross_img_feats, cross_tab_feats)

        # NOTE(review): the InfoMax term is subtracted, i.e. the optimiser is
        # pushed to increase it -- confirm the sign is intentional.
        loss = loss - self.PROJ_LOSS_WEIGHT * proj_loss + self.CROSS_LOSS_WEIGHT * cross_loss

        return {
            "img_out": img_out,
            "table_embed_out": table_embed_out,
            "concat_out": concat_out,
            "label": label,
            "img_loss": img_loss,
            "table_embed_loss": table_embed_loss,
            "concat_loss": concat_loss,
            "cross_loss": cross_loss,
            "proj_loss": proj_loss,
            "loss": loss,
        }

    def training_step(self, batch, batch_idx):
        """Compute, log and return the combined training loss for one batch."""
        step = self._shared_step(batch)

        self.log("img_loss", step["img_loss"])
        self.log("tab_embed_loss", step["table_embed_loss"])
        self.log("concat_loss", step["concat_loss"])
        self.log("proj_loss", step["proj_loss"])
        self.log("cross_loss", step["cross_loss"])
        self.log("loss", step["loss"])
        return step["loss"]

    def validation_step(self, batch, batch_idx):
        """Compute and log validation losses and per-head batch accuracies."""
        step = self._shared_step(batch)
        label = step["label"]

        val_acc = self.test_acc(step["concat_out"], label).item()
        val_acc_img = self.test_acc(step["img_out"], label).item()
        val_acc_table = self.test_acc(step["table_embed_out"], label).item()

        self.log("val_img_loss", step["img_loss"])
        self.log("val_tab_embed_loss", step["table_embed_loss"])
        self.log("val_concat_loss", step["concat_loss"])
        self.log("val_proj_loss", step["proj_loss"])
        self.log("val_cross_loss", step["cross_loss"])
        self.log("val_acc", val_acc, on_step=False, on_epoch=True)
        self.log("val_acc_img", val_acc_img, on_step=False, on_epoch=True)
        self.log("val_acc_tab", val_acc_table, on_step=False, on_epoch=True)
        self.log("val_loss", step["loss"])
        return step["loss"]

    def test_step(self, batch, batch_idx):
        """Return the fused-head loss and detached predictions for one test batch."""
        (table_features_con, table_features_cat, image_features), label = batch
        img_out, table_embed_out, concat_out, _, _, _, _ = self.net_img_clf(image_features, table_features_con, table_features_cat)
        loss = F.cross_entropy(concat_out, label)
        return {
            "loss": loss,
            "preds": concat_out.detach(),
            "img_preds": img_out.detach(),
            "table_preds": table_embed_out.detach(),
            "y": label.detach()
        }

    def test_step_end(self, outputs):
        """Log per-head test accuracy and the test loss from ``test_step`` outputs.

        NOTE(review): the ``*_step_end`` hooks appear to have been removed in
        Lightning 2.0 -- confirm the installed version still calls this hook,
        otherwise no test metrics are logged.
        """
        test_acc = self.test_acc(outputs['preds'], outputs['y']).item()
        self.log("test_acc", test_acc, on_epoch=True, on_step=False)
        img_test_acc = self.test_acc(outputs['img_preds'], outputs['y']).item()
        self.log("test_acc_img", img_test_acc, on_epoch=True, on_step=False)
        table_test_acc = self.test_acc(outputs['table_preds'], outputs['y']).item()
        self.log("test_acc_tab", table_test_acc, on_epoch=True, on_step=False)
        self.log("test_loss", outputs["loss"].mean(), on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Plain SGD (lr 1e-2, weight decay 1e-4) over the wrapped network."""
        optimizer = optim.SGD(self.net_img_clf.parameters(), lr=1e-2, weight_decay=1e-4)
        return optimizer

In [17]:
from torchvision import transforms
class DVMConCatImageDataset(Dataset):
    """DVM samples pairing an image with continuous and categorical table features.

    Each item is ``((con_features, cat_features, image), label)`` where the
    label comes from the ``Genmodel_ID_encode`` column and the image is loaded
    from ``image_path / <Image_path column>``.

    Args:
        table_path: CSV file holding the tabular features, the ``Image_path``
            column and the ``Genmodel_ID_encode`` label column.
        image_path: Directory that the ``Image_path`` values are relative to.

    NOTE(review): the random crop/flip augmentation is applied to every split
    built with this class, and the z-score statistics are computed per CSV
    rather than shared from the training split -- confirm this is intended.
    """

    def __init__(self, table_path: str, image_path: str):
        self.table_path = table_path
        self.image_path = image_path

        image_size = 224
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # Standard ImageNet channel mean / std.
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])])

        self.con_cols = ['Width', 'Length', 'Height', 'Wheelbase', 'Price', 'Adv_year', 'Reg_year',
                         'Adv_month', 'Runned_Miles', 'Seat_num', 'Door_num', 'Engin_size', 'Entry_price']
        self.cat_cols = ['Bodytype', 'Gearbox', 'Fuel_type']

        self.table_df = pd.read_csv(self.table_path)
        self.cat_cardinalities, self.data_indice = self._prepare_table(
            self.table_df, self.con_cols, self.cat_cols)

    @staticmethod
    def _prepare_table(table_df, con_cols, cat_cols):
        """Normalise ``table_df`` in place and extract one record per row.

        Categorical columns whose maximum exceeds 100 are bucketed into 10
        bins; continuous columns are z-score normalised.

        Returns:
            ``(cat_cardinalities, records)`` where ``cat_cardinalities`` is the
            per-column maximum plus one and ``records`` is a list of
            ``(image_id, con_features, cat_features, label)`` tuples.
        """
        for col in cat_cols:
            if table_df[col].max() > 100:
                table_df[col] = pd.cut(table_df[col], bins=10, labels=range(10))
        for col in con_cols:
            # z-score normalization (torch.std is the unbiased estimator)
            col_data = torch.tensor(table_df[col].values, dtype=torch.float)
            mean = torch.mean(col_data)
            std = torch.std(col_data)
            # Hand pandas a NumPy array rather than a torch tensor.
            table_df[col] = torch.div(torch.sub(col_data, mean), std).numpy()
        cat_cardinalities = [i + 1 for i in table_df[cat_cols].max().tolist()]

        # Extract whole columns at once instead of iterating row by row.
        features_con = table_df[con_cols].to_numpy(dtype=float)
        features_cat = table_df[cat_cols].to_numpy(dtype=int)
        image_ids = table_df["Image_path"].tolist()
        labels = table_df["Genmodel_ID_encode"].tolist()
        records = list(zip(image_ids, features_con, features_cat, labels))
        return cat_cardinalities, records

    def __getitem__(self, index):
        """Load and transform the image for ``index`` and return it with its features."""
        image_id, features_con, features_cat, label = self.data_indice[index]
        image_path = os.path.join(self.image_path, f"{image_id}")
        # Close the underlying file as soon as the pixels have been decoded.
        with Image.open(image_path) as raw_image:
            image = self.transform(raw_image.convert('RGB'))

        table_features_con = torch.tensor(features_con, dtype=torch.float)
        table_features_cat = torch.tensor(features_cat, dtype=torch.long)

        label_tensor = torch.tensor(label, dtype=torch.long)

        return (table_features_con, table_features_cat, image), label_tensor

    def __len__(self):
        """Number of samples (rows of the table)."""
        return len(self.data_indice)

In [None]:
# --- Data: one dataset and one loader per split -----------------------------
train_dataset = DVMConCatImageDataset(
    table_path="dataset/DVM/dataset_train.csv", image_path="raw_dataset/DVM")
valid_dataset = DVMConCatImageDataset(
    table_path="dataset/DVM/dataset_valid.csv", image_path="raw_dataset/DVM")
test_dataset = DVMConCatImageDataset(
    table_path="dataset/DVM/dataset_test.csv", image_path="raw_dataset/DVM")

# Only the training loader shuffles.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64, num_workers=8)
valid_loader = DataLoader(valid_dataset, shuffle=False, batch_size=32, num_workers=8)
test_loader = DataLoader(test_dataset, batch_size=32, num_workers=8)

# --- Model: table layout is taken from the training split -------------------
n_num_features = len(train_dataset.con_cols)
cat_cardinalities = train_dataset.cat_cardinalities

model = ImageModelDVMWithTab(
    model_name='resnet',
    n_num_features=n_num_features,
    cat_cardinalities=cat_cardinalities,
)

# --- Training: checkpoint and early-stop on the validation loss -------------
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=6,
    save_last=True,
    save_weights_only=False,
    dirpath=None,
)
early_stopping = pl.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=9)

# NOTE(review): `device` and `logger` are not defined in any visible cell, so
# this cell raises NameError on a fresh kernel -- define them in a config cell.
trainer = pl.Trainer(
    max_epochs=99,
    accelerator="gpu",
    devices=device,
    callbacks=[checkpoint_callback, early_stopping],
    logger=logger,
)
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

# Evaluate the best checkpoint (lowest val_loss) on the held-out test split.
trainer.test(model=model, dataloaders=test_loader, ckpt_path='best')