In [1]:
import pathlib

# Root folder of the dataset; the glob below looks for test/<class>/<clip>.avi under it.
dataset_root_path = pathlib.Path("UCF101_subset")

# Collect every test clip; the name of each clip's parent folder is used as its class label.
test_video_file_paths = list(dataset_root_path.glob("test/*/*.avi"))
class_labels = sorted(set(path.parent.name for path in test_video_file_paths))

# Two-way mapping between class names and integer ids, numbered in sorted-name order.
id2label = dict(enumerate(class_labels))
label2id = {label: i for i, label in id2label.items()}
print(f"Unique classes: {list(label2id.keys())}.")

Unique classes: ['ApplyEyeMakeup', 'ApplyLipstick', 'Archery', 'BabyCrawling', 'BalanceBeam', 'BandMarching', 'BaseballPitch', 'Basketball', 'BasketballDunk', 'BenchPress'].


In [2]:
import pytorchvideo.data
import os

from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    Compose,
    Lambda,
    Resize,
)

# Per-channel statistics handed to Normalize (applied after the division by 255 below).
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
# Target size passed to Resize.
resize_to = (224, 224)

# Clip length in seconds, derived as frames * sampling stride / fps.
num_frames_to_sample = 32
sample_rate = 4
fps = 30
clip_duration = num_frames_to_sample * sample_rate / fps

# Operations applied, in this order, to the "video" entry of every sample.
_video_ops = [
    UniformTemporalSubsample(num_frames_to_sample),
    Lambda(lambda x: x / 255.0),
    Normalize(mean, std),
    Resize(resize_to, antialias=False),
]
val_transform = Compose(
    [ApplyTransformToKey(key="video", transform=Compose(_video_ops))]
)

# Test split: clips of `clip_duration` seconds drawn with the "uniform" sampler, audio skipped.
_test_clip_sampler = pytorchvideo.data.make_clip_sampler("uniform", clip_duration)
test_dataset = pytorchvideo.data.Ucf101(
    data_path=os.path.join(dataset_root_path, "test"),
    clip_sampler=_test_clip_sampler,
    decode_audio=False,
    transform=val_transform,
)



In [3]:
def inference(model, batch, device):
    """Run one batch through ``model`` and return the predicted class indices.

    Args:
        model: Callable accepting a ``pixel_values`` keyword argument and
            returning an object that exposes a ``logits`` attribute.
        batch: Mapping whose ``'video'`` entry is a tensor; its dimensions 1
            and 2 are swapped before it is handed to the model.
        device: Device the swapped tensor is moved to before the forward pass.

    Returns:
        Tensor holding the argmax of the logits over their last dimension.
    """
    # bs, 3, 32, 224, 224 -> bs, 32, 3, 224, 224
    pixel_values = batch["video"].transpose(1, 2).to(device)

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
        predicted_ids = logits.argmax(dim=-1)

    return predicted_ids

Initialization

In [4]:
import numpy as np
import torch
import evaluate
from torch.utils.data import DataLoader
from transformers import VivitForVideoClassification
from model_encryption import weight_extracting, weight_reloading

# Release cached CUDA memory, then prefer the GPU whenever one is available.
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The key file stores a Python object inside a .npy container; .item() unwraps it.
# NOTE(review): allow_pickle=True unpickles the file on load — only use key files you trust.
key_dict = np.load("key_dicts/key-32-2-16-seed100.npy", allow_pickle=True).item()

# Fine-tuned checkpoint, loaded with the label mappings built from the test folders.
model_ckpt = "checkpoints/vivit-b-16x2-kinetics400-finetuned-ucf101-subset—withouImgP/checkpoint-370"
model = VivitForVideoClassification.from_pretrained(
    model_ckpt,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=False,  # size mismatches between checkpoint and config are not ignored
).to(device)
# Switch to evaluation mode before running any inference.
model.eval()

testing_loader = DataLoader(test_dataset, batch_size=5)

Classification with plain videos

In [5]:
# Baseline: accuracy of the model on the unmodified test clips.
metric = evaluate.load("accuracy")

# The batch index is not needed, so iterate over the loader directly.
for batch in testing_loader:
    predictions = inference(model, batch, device)
    metric.add_batch(predictions=predictions, references=batch["label"])

acc = metric.compute()
# Bare last expression so the accuracy dict is rendered as the cell output.
acc

{'accuracy': 0.9655172413793104}

Classification with encrypted videos

In [6]:
from model_encryption import cube_embedding_shuffling, pos_embedding_shuffling

# Model encryption: extract the cube-embedding and position-embedding weights
# from the ViViT embedding module and shuffle each one with its own key.
ce_weight, pos_weight = weight_extracting(model.vivit.embeddings)
shuffled_ce_weight = cube_embedding_shuffling(ce_weight, key_dict['ce_key'])
shuffled_pos_weight = pos_embedding_shuffling(pos_weight, key_dict['pos_key'])

# Reload BOTH shuffled weights. The original call passed the untouched
# `pos_weight`, which left `shuffled_pos_weight` unused.
# NOTE(review): this overwrites model.vivit.embeddings, so re-running the cell
# without reloading the checkpoint presumably shuffles already-shuffled weights.
model.vivit.embeddings = weight_reloading(
    model.vivit.embeddings, shuffled_ce_weight, shuffled_pos_weight
)

Shuffling weight of Patch embedding...
Shuffling weight of Position embedding...


In [7]:
# Import only the helpers this cell uses instead of a wildcard import, so the
# origin of every name is visible and nothing else in the namespace is shadowed.
from video_encryption import (
    cube_division,
    cube_integration,
    cube_pix_shuffling,
    cube_pos_shuffling,
)

# Accuracy of the key-shuffled model on clips encrypted with the same keys.
metric = evaluate.load("accuracy")
for batch in testing_loader:
    # Swap dims 1 and 2 before handing the clip to the encryption helpers.
    video_tensor = batch['video'].transpose(1, 2)

    # Encryption: split into cubes, shuffle pixels inside each cube with
    # 'ce_key', shuffle cube positions with 'pos_key', then reassemble.
    cube_group = cube_division(video_tensor)
    cube_group = cube_pix_shuffling(cube_group, key_dict['ce_key'])
    cube_group = cube_pos_shuffling(cube_group, key_dict['pos_key'])
    # Swap the dims back because inference() performs the same transpose itself.
    encrypted_video = cube_integration(cube_group).transpose(1, 2)
    batch['video'] = encrypted_video

    predictions = inference(model, batch, device)
    metric.add_batch(predictions=predictions, references=batch["label"])

acc = metric.compute()
# Bare last expression so the accuracy dict is rendered as the cell output.
acc

{'accuracy': 0.9655172413793104}

In [8]:
import torch

# Create a random float tensor of shape (1, 3, 224, 224).
input_tensor = torch.randn(1, 3, 224, 224).float()

print("输入张量:", input_tensor, sep="\n")

# Unfold into non-overlapping 16x16 blocks (kernel_size == stride == 16);
# the result has shape (1, 3*16*16, 14*14) = (1, 768, 196).
unfolded = torch.nn.Unfold(kernel_size=16, stride=16)(input_tensor)

print("\nunfold 后的张量:", unfolded, sep="\n")


输入张量:
tensor([[[[ 0.0602,  1.7181,  1.3446,  ..., -1.1445,  0.9502, -0.7615],
          [-1.8114,  0.1623, -2.0558,  ..., -0.5184, -2.4697, -0.0365],
          [ 0.7450, -0.2364, -0.7323,  ...,  0.3852, -0.3082,  0.7032],
          ...,
          [ 2.1919,  0.2222, -0.7634,  ...,  0.5040,  0.3767, -0.6772],
          [-0.5092,  1.5969, -0.2061,  ..., -1.2855, -0.1036,  0.2316],
          [-1.2236,  0.9498, -0.3940,  ..., -0.7307,  0.1755, -1.3499]],

         [[-0.0567, -0.2995, -0.1699,  ...,  0.4523,  0.3211,  1.5771],
          [-0.5349,  0.9403, -1.4165,  ...,  0.1697,  2.0208,  0.0340],
          [ 1.2738,  0.6613,  0.8278,  ..., -0.7007,  0.1385,  0.3181],
          ...,
          [-1.7200, -0.0846,  0.7557,  ...,  0.2991, -0.1352, -2.1994],
          [ 0.4339,  0.8423,  0.8979,  ..., -0.0877,  0.0953,  2.6410],
          [-0.7873,  0.9222, -1.2158,  ...,  0.4778, -1.4454, -1.1518]],

         [[ 0.9152,  1.8353,  0.4381,  ..., -1.6033, -0.8731, -0.2577],
          [ 2.1494,  0.1

tensor([[[ 0.0602,  0.6254,  1.1399,  ...,  1.0153, -0.3031, -0.6993],
         [ 1.7181,  0.9192, -1.4680,  ..., -1.3723, -1.0965, -0.8489],
         [ 1.3446, -1.1684,  0.4630,  ..., -2.6756,  0.9997, -1.9105],
         ...,
         [-0.9198,  0.1960, -0.3541,  ..., -0.6296, -1.2180, -0.0956],
         [-0.8663,  1.3804,  1.0156,  ..., -0.5488,  1.7293, -0.7356],
         [-0.3611,  0.0061, -0.9260,  ...,  1.9497,  1.1686,  1.2399]]])
