-
Notifications
You must be signed in to change notification settings - Fork 19
/
cpc.py
196 lines (170 loc) · 9 KB
/
cpc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/17_cpc.ipynb (unless otherwise specified).
__all__ = ['CPCEncoder', 'CPCModel']
# Cell
import torch
import torch.nn.functional as F
import torch.nn as nn
from clinical_ts.basic_conv1d import _conv1d
import numpy as np
from .basic_conv1d import listify, bn_drop_lin
# Cell
class CPCEncoder(nn.Sequential):
    """CPC encoder: a stack of `_conv1d` blocks applied one after another.

    One block is created per (stride, kernel size, feature) triple; the input
    width of each block is the output width of the previous one.

    Args:
        input_channels: number of channels of the input signal.
        strides: stride of each block.
        kss: kernel size of each block.
        features: number of output channels of each block.
        bn: forwarded to `_conv1d` as its `bn` argument.

    Attributes:
        downsampling_factor: product of all strides.
        output_dim: number of output channels of the last block.

    Raises:
        ValueError: if `strides`, `kss` and `features` differ in length.
    """

    def __init__(self, input_channels, strides=(5, 4, 2, 2, 2), kss=(10, 8, 4, 4, 4),
                 features=(512, 512, 512, 512, 512), bn=False):
        # Explicit check instead of `assert`: a length mismatch would otherwise be
        # silently truncated by zip() when Python runs with -O, leaving
        # downsampling_factor out of sync with the layers actually built.
        if not (len(strides) == len(kss) == len(features)):
            raise ValueError(
                "strides, kss and features must have the same length, got "
                f"{len(strides)}, {len(kss)} and {len(features)}"
            )
        # input width of layer i is the output width of layer i-1
        in_channels = [input_channels, *features[:-1]]
        layers = [
            _conv1d(n_in, n_out, kernel_size=k, stride=s, bn=bn)
            for n_in, n_out, k, s in zip(in_channels, features, kss, strides)
        ]
        super().__init__(*layers)
        self.downsampling_factor = np.prod(strides)
        self.output_dim = features[-1]
        # output: bs, output_dim, seq//downsampling_factor

    def encode(self, input):
        """Run the encoder and swap axes 1 and 2 of the result.

        The returned layout (bs, seq//downsampling, output_dim) is the standard
        ordering for batch_first RNNs. The input is passed through as is; it is
        not truncated to a multiple of `downsampling_factor`.
        """
        return self.forward(input).transpose(1, 2)
# Cell
class CPCModel(nn.Module):
    """CPC model: optional encoder, an RNN, and a projection or classifier head.

    With `num_classes=None` the model is in pretraining mode: `forward` returns
    the encoded input together with a projection of the RNN output, and
    `cpc_loss` can be used. With `num_classes` set, a classification head is
    built on top of the RNN output instead.

    Args:
        input_channels: number of channels of the input signal.
        strides, kss, features, bn_encoder: forwarded to `CPCEncoder`.
        n_hidden: hidden size of the RNN.
        n_layers: number of RNN layers.
        mlp: pretraining only; use a two-layer projection (additional hidden
            layer as in SimCLR) instead of a single linear layer.
        lstm: use `nn.LSTM` if exactly `True`, otherwise `nn.GRU`.
        bias_proj: whether the (last) projection layer has a bias.
        num_classes: `None` for pretraining, else number of classifier outputs.
        concat_pooling: classifier only; pool the RNN output with
            `AdaptiveConcatPoolRNN` instead of taking the last time step.
        ps_head: classifier only; dropout probability (or list of them, as
            accepted by `listify`). A single value p is expanded to p/2 for
            every hidden layer and p for the output layer.
        lin_ftrs_head: classifier only; widths of the hidden linear layers,
            or `None` for no hidden layer.
        bn_head: classifier only; forwarded to `bn_drop_lin`.
        skip_encoder: unless exactly `False`, no encoder is built and the input
            is fed to the RNN directly (classifier mode only).

    Raises:
        ValueError: if the encoder is skipped in pretraining mode.
    """

    def __init__(self, input_channels, strides=(5, 4, 2, 2, 2), kss=(10, 8, 4, 4, 4),
                 features=(512, 512, 512, 512, 512), bn_encoder=False, n_hidden=512, n_layers=2,
                 mlp=False, lstm=True, bias_proj=False, num_classes=None, concat_pooling=True,
                 ps_head=0.5, lin_ftrs_head=(512,), bn_head=True, skip_encoder=False):
        super().__init__()
        # pretraining only with encoder: the projection targets the encoder output
        if skip_encoder is not False and num_classes is None:
            raise ValueError("skip_encoder requires num_classes: pretraining needs the encoder")

        use_encoder = skip_encoder is False
        if use_encoder:
            self.encoder = CPCEncoder(input_channels, strides=strides, kss=kss, features=features, bn=bn_encoder)
            self.encoder_output_dim = self.encoder.output_dim
            self.encoder_downsampling_factor = self.encoder.downsampling_factor
        else:
            self.encoder = None
            self.encoder_output_dim = None
            self.encoder_downsampling_factor = None

        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.mlp = mlp
        self.num_classes = num_classes
        self.concat_pooling = concat_pooling

        # Computed once for both RNN types so that GRU + skip_encoder does not
        # dereference the missing encoder.
        rnn_input_dim = self.encoder_output_dim if use_encoder else input_channels
        rnn_cls = nn.LSTM if lstm is True else nn.GRU
        self.rnn = rnn_cls(rnn_input_dim, n_hidden, num_layers=n_layers, batch_first=True)

        if num_classes is None:  # pretraining
            if mlp:  # additional hidden layer as in simclr
                self.proj = nn.Sequential(
                    nn.Linear(n_hidden, n_hidden),
                    nn.ReLU(inplace=True),
                    nn.Linear(n_hidden, self.encoder_output_dim, bias=bias_proj),
                )
            else:
                self.proj = nn.Linear(n_hidden, self.encoder_output_dim, bias=bias_proj)
        else:  # classifier
            # slightly adapted from RNN1d
            layers_head = []
            if concat_pooling:
                layers_head.append(AdaptiveConcatPoolRNN())

            # concat pooling yields avg, max and last step -> three times the width
            nf = 3 * n_hidden if concat_pooling else n_hidden
            if lin_ftrs_head is None:
                lin_ftrs_head = [nf, num_classes]
            else:
                lin_ftrs_head = [nf] + list(lin_ftrs_head) + [num_classes]
            ps_head = listify(ps_head)
            if len(ps_head) == 1:
                ps_head = [ps_head[0] / 2] * (len(lin_ftrs_head) - 2) + ps_head
            # ReLU after every hidden layer, no activation after the output layer
            actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs_head) - 2) + [None]

            for ni, no, p, actn in zip(lin_ftrs_head[:-1], lin_ftrs_head[1:], ps_head, actns):
                layers_head += bn_drop_lin(ni, no, bn_head, p, actn)
            self.head = nn.Sequential(*layers_head)

    def forward(self, input):
        """Run the model on `input` of shape bs, ch, seq.

        Returns:
            Pretraining: tuple `(input_encoded, projection)` where `projection`
            is `self.proj` applied to the RNN output.
            Classifier: the output of `self.head`.
        """
        if self.encoder is not None:
            input_encoded = self.encoder.encode(input)
        else:
            input_encoded = input.transpose(1, 2)  # bs, seq, channels
        output_rnn, _ = self.rnn(input_encoded)  # output_rnn: bs, seq, n_hidden

        if self.num_classes is None:  # pretraining
            return input_encoded, self.proj(output_rnn)

        # classifier
        output = output_rnn.transpose(1, 2)  # bs, n_hidden, seq (i.e. standard CNN channel ordering)
        # Truthiness, matching __init__: the head only contains a pooling layer
        # when concat_pooling is truthy, so any falsy value must take the last step.
        if not self.concat_pooling:
            output = output[:, :, -1]
        return self.head(output)

    def get_layer_groups(self):
        """Return `(encoder, rnn, head)`.

        `head` only exists in classifier mode, so this raises AttributeError on
        a pretraining model; `encoder` is `None` if the encoder was skipped.
        """
        return (self.encoder, self.rnn, self.head)

    def get_output_layer(self):
        """Return the last layer of the classification head."""
        return self.head[-1]

    def set_output_layer(self, x):
        """Replace the last layer of the classification head by `x`."""
        self.head[-1] = x

    def cpc_loss(self, input, target=None, steps_predicted=5, n_false_negatives=9,
                 negatives_from_same_seq_only=False, eval_acc=False):
        """Contrastive loss between projected RNN outputs and future encodings.

        For every step i, the projection at step i is scored (dot product)
        against the encoding at step i+steps_predicted (the positive) and
        against `n_false_negatives` randomly drawn encodings (the negatives).
        A cross-entropy with the positive as class 0 is added up over all steps.

        Args:
            input: model input, shape bs, ch, seq.
            target: unused.
            steps_predicted: offset between the predicting step and the positive.
            n_false_negatives: number of negatives drawn per sample and step.
            negatives_from_same_seq_only: draw negatives only from the sample's
                own sequence instead of from the whole batch.
            eval_acc: additionally return the fraction of (sample, step) pairs
                for which the positive receives the highest score.

        Returns:
            `loss`, or `(loss, accuracy)` if `eval_acc` is set.

        Raises:
            RuntimeError: if the model was built as a classifier.
        """
        if self.num_classes is not None:
            raise RuntimeError("cpc_loss is only available in pretraining mode (num_classes=None)")

        input_encoded, output = self.forward(input)  # both: bs, seq, features
        # flattened over batch and time to index negatives: bs*seq, features
        input_encoded_flat = input_encoded.reshape(-1, input_encoded.size(2))

        bs = input_encoded.size(0)
        seq = input_encoded.size(1)
        n_steps = seq - steps_predicted
        device = input.device

        loss = torch.zeros((), dtype=torch.float32, device=device)
        tp_cnt = torch.zeros((), dtype=torch.int64, device=device)
        # the positive is always candidate 0
        targs = torch.zeros(bs, dtype=torch.int64, device=device)
        if negatives_from_same_seq_only:
            # every sample draws its negatives from its own row of the batch
            own_batch_idxs = torch.arange(0, bs).repeat_interleave(n_false_negatives).to(device)

        for i in range(n_steps):
            pos_step = i + steps_predicted
            positives = input_encoded[:, pos_step].unsqueeze(1)  # bs, 1, encoder_output_dim

            # draw from seq-1 slots per sequence; the slot of the positive is skipped below
            if negatives_from_same_seq_only:
                idxs = torch.randint(0, (seq - 1), (bs * n_false_negatives,)).to(device)
            else:  # negative from everywhere
                idxs = torch.randint(0, bs * (seq - 1), (bs * n_false_negatives,)).to(device)
            idxs_seq = torch.remainder(idxs, seq - 1)  # bs*false_neg
            # shift every index at or after the positive by one so it is never drawn
            idxs_seq2 = torch.where(idxs_seq < pos_step, idxs_seq, idxs_seq + 1)  # bs*false_neg
            if negatives_from_same_seq_only:
                idxs_batch = own_batch_idxs
            else:
                idxs_batch = idxs // (seq - 1)
            # for negatives from everywhere: this skips step i+steps_predicted
            # from the other sequences as well for simplicity
            idxs2_flat = idxs_batch * seq + idxs_seq2

            negatives = input_encoded_flat[idxs2_flat].view(bs, n_false_negatives, -1)  # bs, false_neg, encoder_output_dim
            candidates = torch.cat([positives, negatives], dim=1)  # bs, false_neg+1, encoder_output_dim
            preds = torch.sum(output[:, i].unsqueeze(1) * candidates, dim=-1)  # bs, false_neg+1

            if eval_acc:
                tp_cnt += torch.sum(torch.argmax(preds, dim=-1) == targs)

            loss = loss + F.cross_entropy(preds, targs)

        if eval_acc:
            return loss, tp_cnt.float() / bs / n_steps
        return loss
#copied from RNN1d
class AdaptiveConcatPoolRNN(nn.Module):
    """Concatenate average pool, max pool and final state of an RNN output.

    Args:
        bidirectional: set if the input comes from a bidirectional RNN, whose
            channel axis holds the forward features in its first half and the
            backward features in its second half.
    """

    def __init__(self, bidirectional=False):
        super().__init__()
        self.bidirectional = bidirectional

    def forward(self, x):
        """Pool `x` of shape bs, ch, ts over the time axis; returns bs, 3*ch."""
        avg_pooled = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
        max_pooled = F.adaptive_max_pool1d(x, 1).squeeze(-1)

        if self.bidirectional is False:
            final_state = x[:, :, -1]
        else:
            # The backward direction reads the sequence in reverse, so its final
            # state sits at time step 0 of the second half of the channels.
            half = x.size(1) // 2
            final_state = torch.cat([x[:, :half, -1], x[:, half:, 0]], dim=1)
        return torch.cat([avg_pooled, max_pooled, final_state], dim=1)  # bs, 3*ch
#class CPCClassifier(nn.Module):
# def __init__(self, cpcmodel, num_classes, concat_pooling=True, ps_head=0.5,lin_ftrs_head=None,bn_head=True):
# super().__init__()
# self.cpcmodel = cpcmodel
# self.concat_pooling = concat_pooling
#
# #slightly adapted from RNN1d
# layers_head =[]
# if(self.concat_pooling):
# layers_head.append(AdaptiveConcatPoolRNN())
#
# #classifier
# nf = 3*self.cpcmodel.encoder_output_dim if concat_pooling else self.cpcmodel.encoder_output_dim
# lin_ftrs_head = [nf, num_classes] if lin_ftrs_head is None else [nf] + lin_ftrs_head + [num_classes]
# ps_head = listify(ps_head)
# if len(ps_head)==1:
# ps_head = [ps_head[0]/2] * (len(lin_ftrs_head)-2) + ps_head
# actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs_head)-2) + [None]
#
# for ni,no,p,actn in zip(lin_ftrs_head[:-1],lin_ftrs_head[1:],ps_head,actns):
# layers_head+=bn_drop_lin(ni,no,bn_head,p,actn)
# self.head=nn.Sequential(*layers_head)
#
# def forward(self,input):
# output = self.cpcmodel(input)
# if(self.concat_pooling is False):
# output = output[:,:,-1]
# return self.head(output)
#
# def get_layer_groups(self):
# return (self.cpcmodel.encoder,self.cpcmodel.rnn,self.head)
#
# def get_output_layer(self):
# return self.head[-1]
#
# def set_output_layer(self,x):
# self.head[-1] = x