In [1]:
import os

# One binary label per line of the UCF test list: 0 when the line mentions "Normal", else 1.
labels = []
with open(r"E:\Python test Work\ConAno\labels\UCF_test.txt","r") as f:
    for line in f:
        labels.append(int("Normal" not in line.strip()))

In [2]:
import pickle

import torch.optim
import os
from utils.utils import *
from utils.evaluation import *
from utils.model_utils import *
import numpy as np
from config import CFG
from torch.utils.data import DataLoader
from models.modules import PositionalEncoding1D
from Loss.losses import *
from dataset import MyDataset
import warnings
from utils.evaluation import evaluate_result
import random

# Suppress every Python warning raised after this point.
warnings.filterwarnings("ignore")

# Seed each random number generator used in this notebook (torch, torch CUDA, numpy, stdlib)
# with the same fixed value.
torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Shared log-sigmoid module; the training and validation losses below are -log_theta(logp).
log_theta = torch.nn.LogSigmoid()


def get_dataloader(args: CFG):
    """
    Build the train and test DataLoaders from the files referenced by ``args``.

    :param args: configuration instance; reads ``train_path``, ``label_train_path``,
                 ``test_path``, ``snippets`` and ``Batch_size``.
    :return: ``(train_loader, test_loader)``. The train loader shuffles, the test loader keeps
             the file order.
    """
    # Features are stored flat on disk and regrouped here as [-1, args.snippets, 1024].
    data_train = torch.from_numpy(np.load(args.train_path)).reshape(-1, args.snippets, 1024)
    label_train = torch.load(args.label_train_path)
    dataset_train = MyDataset(data_train, label_train)

    # Take the batch size from the config *instance* that was passed in rather than from the CFG
    # class, so an override on ``args`` is honoured and stays consistent with train_meta_epoch,
    # which also reads ``args.Batch_size``.
    train_loader = DataLoader(dataset_train, batch_size=args.Batch_size, shuffle=True)

    data_test = torch.from_numpy(np.load(args.test_path)).reshape(-1, args.snippets, 1024)

    # The test dataset is built without labels; evaluation reads them from a label file instead.
    label_test = None
    dataset_test = MyDataset(data_test, label_test, mode='test')
    test_loader = DataLoader(dataset_test, batch_size=args.Batch_size, shuffle=False)

    return train_loader, test_loader


def train_meta_epoch(args: CFG, epoch, trainloader, normalizing_flow, optimizer, POS_EMB: PositionalEncoding1D,
                     metric_recoder: MetricRecoder):
    """
    Train ``normalizing_flow`` for one meta epoch, i.e. ``args.sub_epochs`` passes over ``trainloader``.

    The loss depends on the position in the schedule:

    * ``epoch == 0``: ``-logsigmoid(logp / c)`` averaged over every feature of the batch.
    * ``epoch > 0`` and ``sub_epoch < args.normal_sub_epoch``: the same loss, multiplied by the
      ``normal_fl_weighting`` weights when ``args.focal_weighting`` is set.
    * otherwise: the likelihood loss restricted to mask-0 features plus
      ``args.bgspp_lambda * (loss_n_con + loss_a_con_pos)`` from ``calculate_bg_spp_loss``.

    :param args: configuration (device, sub_epochs, normal_sub_epoch, snippets, flow_arch, ...).
    :param epoch: index of the current meta epoch.
    :param trainloader: yields feature batches, noted as [Batch_size, 32, 1024] by the original author.
    :param normalizing_flow: flow model, called as ``model(x)`` or ``model(x, [pos_embed])``
                             depending on ``args.flow_arch``.
    :param optimizer: optimizer over the flow parameters.
    :param POS_EMB: positional-encoding module applied to each [b, n, c] batch.
    :param metric_recoder: updated and printed once per sub epoch with the mean loss
                           (-1 when no batch was counted).
    :return: list with one detached ``logps`` tensor per batch of the last sub epoch
             (empty when ``args.sub_epochs`` is 0).
    """
    import math  # local import so the NaN check does not depend on a star-import providing ``math``

    normalizing_flow.to(args.device)
    normalizing_flow.train()
    adjust_learning_rate(args, optimizer, epoch)
    logps_list = []  # bound up front so the final return is valid even with zero sub epochs
    for sub_epoch in range(args.sub_epochs):
        total_loss, loss_count = 0.0, 0
        logps_list = []
        for loader in trainloader:
            loader = loader.to(args.device)  # [4,32,1024]

            # Per-snippet mask: 0 for the first half of the snippet axis, 1 for the second half.
            # It is sized from the actual batch, so a short final batch (or a loader built with a
            # different batch size) still gets a mask that matches the features.
            half = loader.shape[1] // 2
            m_b = torch.hstack([torch.zeros(half), torch.ones(half)]).unsqueeze(
                0).repeat(loader.shape[0], 1).to(args.device)  # [4,32]

            if epoch == 0 or sub_epoch < args.normal_sub_epoch:  # Normal only
                loader = loader[:, :args.snippets, :]
                m_b = m_b[:, :args.snippets]
            b, n, c = loader.shape

            loader = loader.reshape(-1, c)  # [Batch_size * 16(32) , 1024]
            m_b = m_b.reshape(-1)
            pos_embed = POS_EMB(loader.reshape(b, n, c)).reshape(-1, args.pos_embed_dim).to(
                args.device)  # [B,16(32),1024]

            # Shuffle features, mask and positional embedding with the same permutation.
            perm = torch.randperm(b * n).to(args.device)
            e_b = loader[perm]
            m_b = m_b[perm]
            p_b = pos_embed[perm]

            if args.flow_arch == 'flow_model':
                z, log_jac_det = normalizing_flow(e_b)  # [4*16,1024] , [4*16]
            else:
                z, log_jac_det = normalizing_flow(e_b, [p_b, ])

            if epoch == 0:
                logps = get_logp(c, z, log_jac_det) / c  # [4*16]

                loss = -log_theta(logps).mean()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                loss_count += 1
            else:
                if sub_epoch < args.normal_sub_epoch:
                    logps = get_logp(c, z, log_jac_det)  # Batch_size * 16
                    logps = logps / c
                    if args.focal_weighting:
                        normal_weights = normal_fl_weighting(logps.detach())

                        loss = -log_theta(logps) * normal_weights
                        loss = loss.mean()

                    else:
                        loss = -log_theta(logps).mean()
                else:
                    logps = get_logp(c, z, log_jac_det)

                    logps = logps / c
                    if args.focal_weighting:
                        # Weights are computed on detached values so they act as constants.
                        logps_detach = logps.detach()
                        normal_logps = logps_detach[m_b == 0]
                        anomaly_logps = logps_detach[m_b == 1]
                        nor_weights = normal_fl_weighting(normal_logps)
                        ano_weights = abnormal_fl_weighting(anomaly_logps)
                        weights = nor_weights.new_zeros(logps_detach.shape)
                        weights[m_b == 0] = nor_weights
                        weights[m_b == 1] = ano_weights
                        loss_ml = -log_theta(logps[m_b == 0]) * nor_weights  # (256, )
                        loss_ml = torch.mean(loss_ml)
                    else:

                        loss_ml = -log_theta(logps[m_b == 0])
                        loss_ml = torch.mean(loss_ml)

                    boundaries = get_logp_boundary(logps, m_b, args.pos_beta, args.margin_abnormal_negative,
                                                   args.margin_abnormal_positive, args.normalizer)
                    # boundaries: b_n, b_a_negative, b_a_positive (per the original author's note)

                    if args.focal_weighting:
                        loss_n_con, loss_a_con_pos = calculate_bg_spp_loss(logps, m_b, boundaries,
                                                                           args.normalizer,
                                                                           weights=weights, mode=2)
                    else:
                        loss_n_con, loss_a_con_pos = calculate_bg_spp_loss(logps, m_b, boundaries,
                                                                           args.normalizer, mode=2)
                    loss = loss_ml + args.bgspp_lambda * (loss_n_con + loss_a_con_pos)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # NaN losses are left out of the running mean; the optimizer step above is still
                # taken, exactly as in the original schedule.
                loss_item = loss.item()
                if not math.isnan(loss_item):
                    total_loss += loss_item
                    loss_count += 1
            # Keep only the values: backward() has already consumed the graph, so holding on to
            # the autograd nodes would just retain memory.
            logps_list.append(logps.detach())
        metric_recoder.update(epoch=epoch, sub_epoch=sub_epoch, loss=total_loss / loss_count if loss_count != 0 else -1)
        print(metric_recoder)
    return logps_list


def validate(args: CFG, epoch, data_loader, normalizing_flow, POS_EMBED, metric_recoder: MetricRecoder):
    """
    Evaluate ``normalizing_flow`` on ``data_loader`` without gradient tracking.

    :param args: configuration (device, flow_arch, pos_embed_dim, label_test_path).
    :param epoch: epoch index recorded in ``metric_recoder``.
    :param data_loader: yields feature batches of shape [b, n, dim].
    :param normalizing_flow: flow model; switched to eval mode here.
                             NOTE(review): it is not moved to ``args.device`` in this function,
                             so it is assumed to live there already - confirm against callers.
    :param POS_EMBED: positional-encoding module applied to each [b, n, dim] batch.
    :param metric_recoder: updated with the mean loss and ROC-AUC, then printed.
    :return: ``(mean_loss, roc_auc, logps_list)`` where ``logps_list`` holds one [b, n] tensor
             of ``logp / dim`` per batch.
    """
    print("Compute loss and scores")
    normalizing_flow = normalizing_flow.eval()
    total_loss, loss_count = 0.0, 0
    logps_list = []

    with torch.no_grad():
        for feature in data_loader:
            b, n, dim = feature.shape
            # Flatten with the feature width that was just unpacked instead of a hard-coded 1024,
            # so the two reshapes below always agree with each other.
            feature = feature.to(args.device).reshape(-1, dim)
            pos_embed = POS_EMBED(feature.reshape(b, n, dim)).reshape(-1, args.pos_embed_dim).to(args.device)
            if args.flow_arch == 'flow_model':
                z, log_jac_det = normalizing_flow(feature)
            else:
                z, log_jac_det = normalizing_flow(feature, [pos_embed, ])
            logps = get_logp(dim, z, log_jac_det)

            logps = logps / dim
            loss = -log_theta(logps).mean()
            total_loss += loss.item()
            loss_count += 1
            logps_list.append(logps.reshape(b, n))

        mean_loss = total_loss / loss_count
        scores = convert_to_anomaly_scores(args, logps_list).detach().cpu().numpy()

        roc_auc = evaluate_result(scores, args.label_test_path)
        metric_recoder.update(loss=mean_loss, roc_auc=roc_auc, epoch=epoch)
        print(metric_recoder)
    return mean_loss, roc_auc, logps_list


def train(args: CFG):
    """
    Full training run: build loaders, model and optimizer, then alternate one training meta
    epoch with one validation pass for ``args.num_epochs`` epochs.

    When ``args.save_result`` is set, the model state dict is written to
    ``<args.log_path>/model.pt`` and the two metric recorders are pickled to
    ``train.pickle`` / ``test.pickle`` in the same directory.

    :param args: configuration instance.
    :return: None
    """
    trainloader, test_loader = get_dataloader(args)

    normalizing_flow = get_flow_model(args, 1024)
    optimizer = torch.optim.Adam(normalizing_flow.parameters(), lr=args.lr)
    pos_embed = PositionalEncoding1D(args.pos_embed_dim)
    train_recoder = MetricRecoder(mode='train')
    test_recoder = MetricRecoder(mode='test')
    for epoch in range(args.num_epochs):
        train_meta_epoch(args, epoch, trainloader, normalizing_flow, optimizer, pos_embed,
                         metric_recoder=train_recoder)
        validate(args, epoch, test_loader, normalizing_flow, pos_embed, test_recoder)
    if args.save_result:
        # Create the output directory first so a finished training run is not lost to a missing
        # folder. An empty log_path means "current directory" and needs no creation.
        if args.log_path:
            os.makedirs(args.log_path, exist_ok=True)
        torch.save(normalizing_flow.state_dict(), os.path.join(args.log_path, 'model.pt'))
        with open(os.path.join(args.log_path, 'train.pickle'), "wb") as f:
            pickle.dump(train_recoder, f)

        with open(os.path.join(args.log_path, 'test.pickle'), "wb") as f:
            pickle.dump(test_recoder, f)


# import seaborn as sns
# import matplotlib.pyplot as plt
#
#
def check_result(args, model, checkpoint):
    """
    Load ``checkpoint`` into ``model``, run it over the test loader and return anomaly scores.

    :param args: configuration instance (device, pos_embed_dim, data paths).
    :param model: flow model whose architecture matches the checkpoint.
    :param checkpoint: path to a state dict saved with ``torch.save``.
    :return: numpy array produced by ``convert_to_anomaly_scores`` on the test log-probabilities.
    """
    # map_location lets a checkpoint saved on one device be restored on args.device
    # (for example a CUDA checkpoint on a CPU-only machine).
    model.load_state_dict(torch.load(checkpoint, map_location=args.device))
    model.to(args.device)
    model.eval()
    _, test_loader = get_dataloader(args)
    pos_embed = PositionalEncoding1D(args.pos_embed_dim)
    test_recoder = MetricRecoder(mode='test')
    # validate() also prints the recorder and computes ROC-AUC; only the raw logps are used here.
    _, _, res = validate(args, 0, test_loader, model, pos_embed, test_recoder)

    scores = convert_to_anomaly_scores(args, res).detach().cpu().numpy()
    return scores


# train_loader, test_loader = get_dataloader(args)
# print(len(train_loader))


In [3]:
# Instantiate the configuration and build two flow models from it (feature dimension 1024).
# Both start from get_flow_model's initial weights; checkpoints are loaded into them further down.
args = CFG()
model = get_flow_model(args, 1024)
model2 = get_flow_model(args,1024)

# score = check_result(args, model, checkpoint)

Conditional Normalizing Flow => Feature Dimension:  1024
Conditional Normalizing Flow => Feature Dimension:  1024


In [4]:
# Paths of the two saved state dicts to compare; both live under args.log_path.
checkpoint = os.path.join(args.log_path, 'model.pt')
checkpoint2 = os.path.join(args.log_path,'model_constrastive_new.pt')

In [6]:
# 1281,16,1024 => -1,1024 + Pos_embedding => Model => Z (-1,1024) => Normal Distribution

def get_score_two_models(model,model2,checkpoint1,checkpoint2,data_path):
    """
    Run two checkpoints over the same feature file and return their scores side by side.

    Relies on the notebook-level ``args`` for ``pos_embed_dim`` and for
    ``convert_to_anomaly_scores``.

    :param model: flow model that receives ``checkpoint1``.
    :param model2: flow model that receives ``checkpoint2``.
    :param checkpoint1: path to the first state dict.
    :param checkpoint2: path to the second state dict.
    :param data_path: ``.npy`` feature file, regrouped here as [-1, 32, 1024].
    :return: ``(score, score2, logp, logp2)`` - anomaly scores of each model followed by the
             flattened ``get_logp(...) / 1024`` values of each model.
    """
    model.load_state_dict(torch.load(checkpoint1))
    # NOTE(review): both models are put in train() mode although this is inference under
    # no_grad, whereas validate() uses eval(). Left as is - confirm this is intentional.
    model.train()
    model2.load_state_dict(torch.load(checkpoint2))
    model2.train()
    model = model.to('cuda')
    model2 = model2.to('cuda')
    data = torch.from_numpy(np.load(data_path)).reshape(-1,32,1024)
    data_loader = DataLoader(data,batch_size = 16,shuffle = False)

    # The positional encoder does not depend on the batch, so build it once instead of once
    # per iteration (train/validate also reuse a single instance across batches).
    POS = PositionalEncoding1D(args.pos_embed_dim).to('cuda')

    output = []
    output2 = []
    output_det = []
    output_det2 = []
    with torch.no_grad():
        for d in data_loader:
            inp = d.to('cuda').reshape(-1,1024)
            pos = POS(d).reshape(-1,args.pos_embed_dim).to('cuda')

            # Each model returns (z, log_jac_det); both see identical inputs.
            out = model(inp,[pos,])
            out2 = model2(inp,[pos,])
            output.append(out[0].reshape(-1,32,1024))
            output2.append(out2[0].reshape(-1,32,1024))
            output_det.append(out[1].reshape(-1,32))
            output_det2.append(out2[1].reshape(-1,32))

    output = torch.concat(output,dim=0).reshape(-1,1024)
    output_det = torch.concat(output_det,dim =0).reshape(-1)
    output2 = torch.concat(output2,dim=0).reshape(-1,1024)
    output_det2 = torch.concat(output_det2,dim =0 ).reshape(-1)

    # Compute each log-probability once and reuse it for both the score and the return value.
    logp = get_logp(1024,output,output_det) / 1024
    logp2 = get_logp(1024,output2,output_det2) / 1024

    # convert_to_anomaly_scores is defined elsewhere; it is handed a copy so the logp tensors
    # returned below cannot be affected should it modify its argument in place.
    score = convert_to_anomaly_scores(args,logp.clone())
    score2 = convert_to_anomaly_scores(args,logp2.clone())

    return score,score2,logp,logp2



In [7]:
# Score the test features with both checkpoints. `output` / `output2` receive the per-feature
# log-probabilities (already divided by 1024) that get_score_two_models returns last.
score,score2,output,output2 = get_score_two_models(model,model2,checkpoint,checkpoint2,r'E:/2023/NaverProject/LastCodingProject/Binary_file/X_test_flow.npy')
# score_train,score2_train,logp,logp2 = get_score_two_models(model,model2,checkpoint,checkpoint2,r'E:/2023/NaverProject/LastCodingProject/Binary_file/X_train_flow.npy')

In [11]:
def validate_training(args:CFG,model_constrastive):
    """
    Run ``model_constrastive`` over the training loader and collect its log-probabilities.

    :param args: configuration instance (data paths, pos_embed_dim).
    :param model_constrastive: flow model with weights already loaded; moved to CUDA here.
    :return: list with one tensor per batch holding ``get_logp(...) / 1024`` reshaped to
             [-1, 64]. The train loader shuffles, so batch order differs between calls.
    """
    train_loader,_ = get_dataloader(args)
    # NOTE(review): train() mode is kept from the original even though no gradients are
    # taken - confirm this is intentional.
    model_constrastive.train()
    model_constrastive.to('cuda')
    # Batch-independent, so constructed once rather than once per iteration.
    POS = PositionalEncoding1D(args.pos_embed_dim)
    logp_list = []
    for data in train_loader:
        inp = data.reshape(-1,1024).to('cuda')
        pos = POS(data).to('cuda').reshape(-1,args.pos_embed_dim)
        with torch.no_grad():
            z,logdet = model_constrastive(inp,[pos])

        logp = get_logp(1024,z,logdet) / 1024
        # 64 values per row - presumably the snippet count of the training features; verify
        # against args.snippets.
        logp_list.append(logp.reshape(-1,64))
    return logp_list
        
# Re-create the configuration and overwrite `model`'s weights with the second checkpoint.
args = CFG()
model.load_state_dict(torch.load(checkpoint2))
# logp_list = validate_training(args,model)

<All keys matched successfully>

In [10]:
# logps_tensor = torch.concat(logp_list,dim = 0).reshape(-1)
# label = torch.hstack([torch.zeros(32), torch.ones(32)]).unsqueeze(
#         0).repeat(810, 1).reshape(-1)

def save_visualization2(logps, labels, name_fig):
    """
    Plot the log-probability distributions of normal and abnormal samples.

    A red dashed vertical line marks the value found at rank ``int(0.4 * N)`` of the sorted
    normal log-probabilities.

    :param logps: [Batch_size * 32]
    :param labels: [Batch_size * 32]; 0 selects normal entries, anything else abnormal
    :param name_fig: output path for the figure, or None to show it instead
    :return: None
    """
    normal_part = logps[labels == 0]
    abnormal_part = logps[labels != 0]
    print(normal_part.shape,abnormal_part.shape)

    # Index of the normal sample sitting at the 40% position once sorted ascending.
    _, order = torch.sort(normal_part)
    boundary_idx = order[int(len(normal_part) * 0.4)]
    b_n = normal_part[boundary_idx]

    plt.figure()
    for values, tag in ((normal_part, 'normal'), (abnormal_part, 'abnormal')):
        sns.distplot(values.detach().cpu().numpy(), label=tag)
    plt.axvline(b_n.cpu().item(), color='red', linestyle='--')
    plt.legend()

    if name_fig is None:
        plt.show()
    else:
        plt.savefig(name_fig)
    plt.close()


# save_visualization2(logps_tensor,label,None)

In [43]:
def postpress(curve, seg_size=32):
    """
    Piecewise-constant smoothing: split ``curve`` into ``seg_size`` equal windows and replace
    every value by the mean of its window.

    Samples left over after the last full window (when the length is not a multiple of
    ``seg_size``) are replaced by their own mean.

    :param curve: 1-D numpy array
    :param seg_size: number of equal-width windows
    :return: new array with the same shape and dtype as ``curve``
    """
    total = curve.shape[0]
    width = total // seg_size
    smoothed = np.zeros_like(curve)

    for seg in range(seg_size):
        lo, hi = width * seg, width * (seg + 1)
        smoothed[lo:hi] = np.mean(curve[lo:hi])

    # Tail that did not fit into a full window.
    covered = width * seg_size
    if covered < total:
        smoothed[covered:] = np.mean(curve[covered:])
    return smoothed


def evaluate_result_tmp(score, label_path):
    """
    Frame-level ROC-AUC of per-snippet scores against the annotations in ``label_path``.

    Each line of the label file is split on single spaces: field 1 is the frame count and the
    fields from index 3 on are read as (start, end) pairs marking anomalous frame ranges.
    A pair is applied only when ``start > 0``.
    NOTE(review): an anomaly beginning exactly at frame 0 is therefore skipped - confirm
    against the label format.

    :param score: indexable by video position; ``score[k]`` holds the snippet scores of the
                  k-th line of the label file
    :param label_path: path to the label text file
    :return: ``(roc_auc, GT_matrix, ANS_matrix)`` with per-video ground-truth and score lists
    """
    ground_truths = []
    with open(label_path, 'r') as handle:
        for row in handle:
            n_frames = int(row.strip().split(' ')[1])
            frame_gt = np.zeros((n_frames,), dtype=np.int8)
            fields = row.split(' ')[3:]
            # Consecutive fields form (start, end) pairs; an unpaired trailing field is ignored.
            for first, last in zip(fields[0::2], fields[1::2]):
                start = int(first)
                end = int(last)
                if start > 0:
                    frame_gt[start:end] = 1
            ground_truths.append(frame_gt)

    GT, ANS = [], []
    GT_matrix, ANS_matrix = [], []

    for vid, cur_gt in enumerate(ground_truths):
        cur_ab = score[vid]
        # Stretch the snippet scores over the frame axis: snippet i covers frames [lo, hi).
        frames_per_snippet = float(len(cur_gt)) / float(len(cur_ab))
        cur_ans = np.zeros_like(cur_gt, dtype='float32')
        for i in range(len(cur_ab)):
            lo = int(i * frames_per_snippet + 0.5)
            hi = int((i + 1) * frames_per_snippet + 0.5)
            cur_ans[lo:hi] = cur_ab[i]
        cur_ans = postpress(cur_ans, seg_size=64)

        gt_as_list = cur_gt.tolist()
        ans_as_list = cur_ans.tolist()
        GT_matrix.append(gt_as_list)
        ANS_matrix.append(ans_as_list)
        GT.extend(gt_as_list)
        ANS.extend(ans_as_list)

    return roc_auc_score(GT, ANS),GT_matrix,ANS_matrix


In [8]:
def validate_logp_normal(logp,label_train):
    """
    Show the log-probability distributions of the rows labelled normal and abnormal.

    A flat ``logp`` is first regrouped into rows of 32 values. Rows whose label is 0 form the
    normal set and all other rows the abnormal set. A red dashed line marks the value at rank
    ``int(0.4 * N)`` of the sorted normal values.

    :param logp: tensor of log-probabilities, either 1-D or already grouped by row
    :param label_train: per-row labels used as a mask on ``logp``
    :return: None
    """
    grouped = logp.reshape(-1,32) if logp.dim() == 1 else logp

    normal_vals = grouped[label_train == 0].reshape(-1).detach().cpu()
    abnormal_vals = grouped[label_train != 0].reshape(-1).detach().cpu()

    # Pick the normal value that sits at the 40% position after an ascending sort.
    _, order = torch.sort(normal_vals)
    b_n = normal_vals[order[int(len(normal_vals) * 0.4)]]

    for values, tag in ((normal_vals, 'normal'), (abnormal_vals, 'abnormal')):
        sns.distplot(values,label = tag)
    plt.axvline(b_n, color='red', linestyle='--')

    plt.legend()
    plt.show()

# validate_logp_normal(logp2,torch.load(r"E:\2023\NaverProject\LastCodingProject\Binary_file\label_train_flow.pt"))