/
noise_detector.py
67 lines (57 loc) · 2.51 KB
/
noise_detector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from minetorch.plugin import Plugin
from minetorch.statable import Statable
class NoiseSampleDetector(Plugin, Statable):
    """Find suspicious (possibly noisy / mislabeled) samples in the datasets.

    Provide a ``metric`` which computes a scalar for every sample; in most
    cases the metric should be the loss function without reduction.  After
    every epoch the per-sample metric is recorded for both the train and the
    validation datasets; the samples whose metric fluctuates the most across
    epochs (largest standard deviation) are reported as suspicious.
    """

    def __init__(self, metric, topn=50):
        """
        Args:
            metric: callable ``(predict, target) -> Tensor`` that yields one
                scalar per sample (i.e. an un-reduced loss).
            topn: how many of the most suspicious sample indices to report.
        """
        super().__init__()
        self.metric = metric
        self.topn = topn
        # One entry per epoch: a 1-D tensor of per-sample metric values.
        self.train_metrics = []
        self.val_metrics = []

    def before_init(self):
        # Register as a statable so the metric history is checkpointed.
        self.miner.statable[self.__class__.__name__] = self
        # Rebuild both dataloaders with shuffle=False so every sample keeps a
        # stable position in the result tensor across epochs.
        self.train_dataloader = torch.utils.data.DataLoader(
            self.miner.train_dataloader.dataset,
            batch_size=self.miner.train_dataloader.batch_size,
            num_workers=self.miner.train_dataloader.num_workers,
            shuffle=False,
        )
        # Fix: use the *validation* dataloader's own batch_size / num_workers
        # (the original copy-pasted the train dataloader's settings here).
        self.val_dataloader = torch.utils.data.DataLoader(
            self.miner.val_dataloader.dataset,
            batch_size=self.miner.val_dataloader.batch_size,
            num_workers=self.miner.val_dataloader.num_workers,
            shuffle=False,
        )

    def load_state_dict(self, data):
        """Restore the metric history saved by :meth:`state_dict`."""
        self.train_metrics = data[0]
        self.val_metrics = data[1]

    def state_dict(self):
        """Return the metric history for checkpointing."""
        return (self.train_metrics, self.val_metrics)

    def after_epoch_end(self, **kwargs):
        """Record this epoch's per-sample metrics and report the top-n
        samples whose metric varies the most across recorded epochs."""
        with torch.no_grad():
            self.train_metrics.append(self._predict_dataset(self.train_dataloader))
            self.val_metrics.append(self._predict_dataset(self.val_dataloader))
        # std over a single epoch is NaN (the unbiased estimator needs
        # n >= 2), which would make the first ranking meaningless — wait
        # until a second epoch has been recorded before reporting.
        if len(self.train_metrics) < 2:
            return
        _, train_indices = torch.sort(
            torch.std(torch.stack(self.train_metrics), dim=0), descending=True
        )
        _, val_indices = torch.sort(
            torch.std(torch.stack(self.val_metrics), dim=0), descending=True
        )
        self.print_txt(
            f"Train dataset most {self.topn} suspicious indices: {train_indices.tolist()[:self.topn]} \n"
            f"Validation dataset most {self.topn} suspicious indices: {val_indices.tolist()[:self.topn]}",
            "suspicious_noise_samples",
        )

    def _predict_dataset(self, dataloader):
        """Run the model over ``dataloader`` and return per-sample metrics.

        Returns:
            A 1-D CPU tensor of length ``len(dataloader.dataset)``, aligned
            with dataset order (the loader is built with shuffle=False).
        """
        # NOTE(review): the model is not switched to eval() here, so dropout
        # and batch-norm behave as in training — confirm this is intended.
        results = torch.zeros([len(dataloader.dataset)])
        for index, data in enumerate(dataloader):
            predict = self.model(data[0].to(self.devices))
            offset = index * dataloader.batch_size
            # A short final batch is fine: the slice is clamped to the tensor
            # end, matching the size of the metric tensor it receives.
            results[offset : offset + dataloader.batch_size] = (
                self.metric(predict, data[1].to(self.devices)).detach().cpu()
            )
        return results