-
Notifications
You must be signed in to change notification settings - Fork 5
/
engine.py
152 lines (134 loc) · 5.98 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import time
from typing import Callable
import faiss
import numpy as np
import torch
from torch import nn
from torchvision.transforms import functional
from . import utils, utils_img
from .attenuations import JND
# Run on GPU when one is visible, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def get_targets(
    target: str,
    index: faiss.Index,
    fts: torch.Tensor,
    ivf_centroids: np.ndarray = None
) -> torch.Tensor:
    """
    Get the target representations for the features.
    Args:
        target (str): Target representation to use. One of
            'pq_recons', 'ori_ft', 'ivf_cluster', 'ivf_cluster_half'.
        index (faiss.Index): Index to use for retrieval.
        fts (torch.Tensor): Features to get the targets for. batch_size x feature_dim
        ivf_centroids (np.ndarray): Centroids of the IVF index.
            Required for the 'ivf_cluster*' targets.
    Returns:
        targets (torch.Tensor): Target representations for the features. batch_size x feature_dim
    Raises:
        NotImplementedError: If `target` is not a supported mode.
    """
    if target == 'pq_recons':
        # Reconstruct the PQ codes of the fts.shape[0] vectors that have just been added.
        targets = index.reconstruct_n(index.ntotal-fts.shape[0], fts.shape[0])
        targets = torch.tensor(targets)
    elif target == 'ori_ft':
        # Fix: the original computed fts.clone() without assigning it, leaving
        # `targets` unbound and raising NameError at the return below.
        targets = fts.clone()
    elif target == 'ivf_cluster':
        # Find the closest IVF cluster center for each feature.
        _, ivf_I = index.quantizer.search(fts.detach().cpu().numpy(), k=1)
        # Use the cluster centroid as the target representation.
        targets = ivf_centroids.take(ivf_I.flatten(), axis=0)
        targets = torch.tensor(targets)
    elif target == 'ivf_cluster_half':
        _, ivf_I = index.quantizer.search(fts.detach().cpu().numpy(), k=1)
        centroids = ivf_centroids.take(ivf_I.flatten(), axis=0)
        # Fix: midpoint between the centroid and the feature. The original
        # `centroids + fts / 2` divided only the features due to precedence.
        targets = (torch.tensor(centroids) + fts.clone()) / 2
    else:
        raise NotImplementedError(f'Invalid target: {target}')
    return targets
def activate_images(
    imgs: list[torch.Tensor],
    ori_fts: torch.Tensor,
    model: nn.Module,
    index: faiss.Index,
    ivf_centroids: np.ndarray,
    attenuation: JND,
    loss_f: Callable,
    loss_i: Callable,
    params: argparse.Namespace
) -> list[torch.Tensor]:
    """
    Activate images: optimize small additive perturbations so that the images'
    extracted features move toward their index targets (indexation loss) while
    the images stay perceptually close to the originals (image loss).
    Args:
        imgs (list of torch.Tensor): Images to activate. batch_size * [3 x height x width]
        ori_fts (torch.Tensor): Original features of the images, used to build
            the optimization targets. batch_size x feature_dim
        model (torch.nn.Module): Model for feature extraction.
        index (faiss.Index): Index to use for retrieval.
        ivf_centroids (np.ndarray): Centroids of the IVF index.
        attenuation (JND): To create Just Noticeable Difference heatmaps.
        loss_f (Callable): Loss function to use for the indexation loss.
        loss_i (Callable): Loss function to use for the image loss.
        params (argparse.Namespace): Parameters. Reads: target, scaling,
            optimizer, scheduler, iterations, use_tanh, scale_channels,
            resize_size, lambda_f, lambda_i, log_freq.
    Returns:
        activated images (list of torch.Tensor): Activated images. batch_size * [3 x height x width]
    """
    # Build one fixed target per image from the original features (see get_targets).
    targets = get_targets(params.target, index, ori_fts, ivf_centroids)
    targets = targets.to(device)
    # Just noticeable difference heatmaps
    # Per-channel scaling: 0.299/0.587/0.114 match the BT.601 luma weights of
    # R/G/B — presumably channels contributing less to luminance get a larger
    # perturbation budget; the 0.072 factor's origin is not visible here.
    alpha = torch.tensor([0.072*(1/0.299), 0.072*(1/0.587), 0.072*(1/0.114)])
    alpha = alpha[:,None,None].to(device) # 3 x 1 x 1
    heatmaps = [params.scaling * attenuation.heatmaps(img) for img in imgs]
    # init distortion + optimizer + scheduler
    # Near-zero random init so the first iterations start from ~the original image.
    deltas = [1e-6 * torch.randn_like(img).to(device) for img in imgs] # b (1 c h w)
    for distortion in deltas:
        # Only the deltas are optimized; images and model stay fixed.
        distortion.requires_grad = True
    optim_params = utils.parse_params(params.optimizer)
    optimizer = utils.build_optimizer(model_params=deltas, **optim_params)
    if params.scheduler is not None:
        scheduler = utils.build_scheduler(optimizer=optimizer, **utils.parse_params(params.scheduler))
    # begin optim
    iter_time = time.time()
    log_stats = []
    for gd_it in range(params.iterations):
        gd_it_time = time.time()
        if params.scheduler is not None:
            # Scheduler is stepped with the iteration index before the grad step.
            scheduler.step(gd_it)
        # perceptual constraints
        # tanh bounds each delta to (-1, 1); alpha rescales per channel; the JND
        # heatmap then modulates the perturbation spatially.
        percep_deltas = [torch.tanh(delta) for delta in deltas] if params.use_tanh else deltas
        percep_deltas = [delta * alpha for delta in percep_deltas] if params.scale_channels else percep_deltas
        imgs_t = [img + hm * delta for img, hm, delta in zip(imgs, heatmaps, percep_deltas)]
        # get features
        # Images may have different sizes; resize to a common size to batch them.
        batch_imgs = [functional.resize(img_t, (params.resize_size, params.resize_size)) for img_t in imgs_t]
        batch_imgs = torch.stack(batch_imgs)
        fts = model(batch_imgs) # b d
        # compute losses
        lf = loss_f(fts, targets)
        li = loss_i(imgs_t, imgs)
        loss = params.lambda_f * lf + params.lambda_i * li
        # step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # log stats
        psnrs = torch.tensor([utils_img.psnr(img_t, img) for img_t, img in zip(imgs_t, imgs)])
        linfs = torch.tensor([utils_img.linf(img_t, img) for img_t, img in zip(imgs_t, imgs)])
        log_stats.append({
            'gd_it': gd_it,
            'loss': loss.item(),
            'loss_f': lf.item(),
            'loss_i': li.item(),
            'psnr': torch.nanmean(psnrs).item(),
            'linf': torch.nanmean(linfs).item(),
            'lr': optimizer.param_groups[0]['lr'],
            'gd_it_time': time.time() - gd_it_time,
            'iter_time': time.time() - iter_time,
            # NOTE(review): torch.cuda.max_memory_allocated() may raise on
            # CPU-only builds even though `device` falls back to CPU — confirm.
            'max_mem': torch.cuda.max_memory_allocated() / (1024*1024),
            'kw': 'optim',
        })
        if (gd_it+1) % params.log_freq == 0:
            print(json.dumps(log_stats[-1]))
            # tqdm.tqdm.write(json.dumps(log_stats[-1]))
    # perceptual constraints
    # Re-apply the same constraint pipeline with the final deltas to build the outputs.
    percep_deltas = [torch.tanh(delta) for delta in deltas] if params.use_tanh else deltas
    percep_deltas = [delta * alpha for delta in percep_deltas] if params.scale_channels else percep_deltas
    imgs_t = [img + hm * delta for img, hm, delta in zip(imgs, heatmaps, percep_deltas)]
    return imgs_t