In [1]:
import torch
from src.helper import load_checkpoint, init_model
from src.datasets.ukbb import make_ukbb
from src.models.vision_transformer import vit_custom
from src.masks.multiblock import MaskCollator as MBMaskCollator
import yaml
from src.transforms import make_transforms
import pprint
import torch.nn.functional as F
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score


In [20]:
with open("configs/configs_vitt.yaml", 'r') as y_file:
        args = yaml.load(y_file, Loader=yaml.FullLoader)
        pp = pprint.PrettyPrinter(indent=4)

# -- META
use_bfloat16 = args['meta']['use_bfloat16']
model_name = args['meta']['model_name']
load_model = args['meta']['load_checkpoint'] or False
r_file = args['meta']['read_checkpoint']
copy_data = args['meta']['copy_data']
pred_depth = args['meta']['pred_depth']
pred_emb_dim = args['meta']['pred_emb_dim']
if not torch.cuda.is_available():
    device = torch.device('cpu')
else:
    device = torch.device('cuda:0')
    torch.cuda.set_device(device)

# -- DATA
rescale_sigma = args['data']['rescale_sigma']
ftsurrogate = args['data']['ftsurrogate']
jitter = args['data']['jitter']
spec_augment = args['data']['spec_augment']
time_flip = args['data']['time_flip']
sign_flip = args['data']['sign_flip']
# --
batch_size = args['data']['batch_size']
pin_mem = args['data']['pin_mem']
num_workers = args['data']['num_workers']
root_path = args['data']['root_path']
image_folder = args['data']['data_path']
val_folder = args['data']['val_path']
downstream_train_path = args['data']['downstream_train_path']
downstream_val_path = args['data']['downstream_val_path']
crop_size = args['data']['crop_size']
crop_scale = args['data']['crop_scale']
# --

# -- MASK
allow_overlap = args['mask']['allow_overlap']  # whether to allow overlap b/w context and target blocks
patch_size = args['mask']['patch_size']  # patch-size for model training
num_enc_masks = args['mask']['num_enc_masks']  # number of context blocks
min_keep = args['mask']['min_keep']  # min number of patches in context block
enc_mask_scale = args['mask']['enc_mask_scale']  # scale of context blocks
num_pred_masks = args['mask']['num_pred_masks']  # number of target blocks
pred_mask_scale = args['mask']['pred_mask_scale']  # scale of target blocks
aspect_ratio = args['mask']['aspect_ratio']  # aspect ratio of target blocks
# --    
# -- make data transforms
mask_collator = MBMaskCollator(
    input_size=crop_size,
    patch_size=patch_size,
    pred_mask_scale=pred_mask_scale,
    enc_mask_scale=enc_mask_scale,
    aspect_ratio=aspect_ratio,
    nenc=num_enc_masks,
    npred=num_pred_masks,
    allow_overlap=allow_overlap,
    min_keep=min_keep)

transform = make_transforms(
    crop_resizing=crop_size,
    ftsurrogate=ftsurrogate,
    jitter=jitter,
    rescale_sigma=rescale_sigma,
    time_flip=time_flip,
    sign_flip=sign_flip,
    spec_augment = spec_augment
    )

_, downstream_train_loader,_ = make_ukbb(
        transform=None,
        batch_size=batch_size,
        collator=mask_collator,
        pin_mem=pin_mem,
        training=True,
        num_workers=num_workers,
        world_size=1,
        rank=0,
        root_path=root_path,
        data_file=downstream_train_path,
        copy_data=False,
        drop_last=True
)
_, downstream_val_loader,_ = make_ukbb(
        transform=None,
        batch_size=batch_size,
        collator=mask_collator,
        pin_mem=pin_mem,
        training=True,
        num_workers=num_workers,
        world_size=1,
        rank=0,
        root_path=root_path,
        data_file=downstream_val_path,
        copy_data=False,
        drop_last=True
)

INFO:root:making ecg data transforms
INFO:root:Initialized UKBB
INFO:root:UKBB dataset created
INFO:root:Ukbb unsupervised data loader created
INFO:root:Initialized UKBB
INFO:root:UKBB dataset created
INFO:root:Ukbb unsupervised data loader created


In [21]:
path= "/vol/aimspace/users/seel/wandb/run-20240206_114842-i17qy14j/files/jepa-latest.pth.tar"

In [22]:
encoder, predictor = init_model(
        device="cuda:0",
        patch_size=(1,100),
        crop_size=(12,500),
        pred_depth=1,
        pred_emb_dim=96,
        model_name=model_name)




INFO:root:VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(1, 192, kernel_size=(1, 100), stride=(1, 100))
  )
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=768, out_features=192, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
)


In [23]:
encoder.pos_embed.shape

torch.Size([1, 600, 192])

In [24]:
checkpoint = torch.load(path, map_location=torch.device('cpu'))
epoch = checkpoint['epoch']
pretrained_dict = checkpoint['encoder']
checkpoint.keys()


dict_keys(['encoder', 'predictor', 'target_encoder', 'opt', 'scaler', 'epoch', 'loss', 'batch_size', 'world_size', 'lr'])

In [25]:
print(checkpoint["scaler"])

None


In [26]:

pretrained_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}
msg = encoder.load_state_dict(pretrained_dict)
print(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}')

loaded pretrained encoder from epoch 200 with msg: <All keys matched successfully>


In [27]:
for p in encoder.parameters():
    p.requires_grad = False

In [11]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self,encoder):
        super().__init__()
        self.encoder = encoder
        
        self.fc1 = nn.Linear(384, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1)

    def forward(self, x):
        x = self.encoder(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.sigmoid(self.fc3(x))
        return x


net = Net(encoder)


In [29]:
from tqdm import tqdm
encoder.eval()
encodings_train = torch.tensor([])
labels_train = torch.tensor([])
encodings_val = torch.tensor([])
labels_val = torch.tensor([])

for itr, (udata, masks_enc, masks_pred) in tqdm(enumerate(downstream_train_loader)):
    def load_imgs():
        # -- unsupervised imgs
        imgs = udata[0].to(device, non_blocking=True)
        labels = udata[1]
        
        return (imgs, labels)
    imgs, labels = load_imgs()
    labels_train=torch.cat((labels_train,labels.cpu()), 0)
    def forward_target():
        with torch.no_grad():
            h = encoder(imgs)
            h = F.layer_norm(h, (h.size(-1),))  # normalize over feature-dim
            return h
    # Step 1. Forward
    with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=False):
        h = forward_target() # shape of h: (B,600,768) e.g. B=32
        encodings_train = torch.cat((encodings_train,h.detach().cpu()), 0)
for itr, (udata, masks_enc, masks_pred) in enumerate(downstream_val_loader):
    def load_imgs():
        # -- unsupervised imgs
        imgs = udata[0].to(device, non_blocking=True)
        labels = udata[1]
        return (imgs, labels)
    imgs, labels = load_imgs()
    labels_val=torch.cat((labels_val,labels.cpu()), 0)
    def forward_target():
        with torch.no_grad():
            h = encoder(imgs)
            h = F.layer_norm(h, (h.size(-1),))  # normalize over feature-dim
            return h
    # Step 1. Forward
    with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=False):
        h = forward_target()
        
        encodings_val=torch.cat((encodings_val,h.detach().cpu()), 0)
        
encodings_train = encodings_train.mean(dim=1)
encodings_val = encodings_val.mean(dim=1)

13it [00:03,  3.33it/s]


In [34]:
pipe = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000,random_state=0,C=0.001))

pipe.fit(
    np.asarray(encodings_train), #.reshape(len(encodings_train),-1)
    np.asarray(labels_train).flatten())
            
train_proba = pipe.predict_proba(
    np.asarray(encodings_train),  #.reshape(len(encodings_train),-1)
    )[:, 1]
            
train_pred = pipe.predict(
    np.asarray(encodings_train),  #.reshape(len(encodings_train),-1)
    )
            
train_acc = accuracy_score(np.asarray(labels_train).flatten(), train_pred)
train_auc = roc_auc_score(np.asarray(labels_train).flatten(), train_proba)
train_f1 = f1_score(np.asarray(labels_train).flatten(), train_pred)
            
val_pred = pipe.predict(
    np.asarray(encodings_val), #.reshape(len(encodings_val),-1)
    )
            
val_proba = pipe.predict_proba(
    np.asarray(encodings_val), #.reshape(len(encodings_val),-1)
    )[:, 1]
            
val_acc = accuracy_score(np.asarray(labels_val).flatten(), val_pred)
val_auc = roc_auc_score(np.asarray(labels_val).flatten(), val_proba)
val_f1 = f1_score(np.asarray(labels_val).flatten(), val_pred)

print(train_acc, train_auc, train_f1)
print(val_acc,val_auc,val_f1)

0.6670673076923077 0.7242697408338389 0.6654589371980677
0.64453125 0.70257568359375 0.6459143968871596


In [None]:
# Optimizers specified in the torch.optim package
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

criterion = torch.nn.BCELoss()
net = net.to(device=device)

In [None]:
for epoch in range(50):  # loop over the dataset multiple times

    running_loss = 0.0
    for itr, (udata, masks_enc, masks_pred) in enumerate(downstream_val_loader):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = udata[0].to(device, non_blocking=True), udata[1].float().to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        outputs=outputs.mean(dim=1)
        #print(outputs, labels)
        loss = criterion(outputs.squeeze(), labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if itr % 20 == 19:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {itr + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')

[1,    20] loss: 0.004
[2,    20] loss: 0.004
[3,    20] loss: 0.004
[4,    20] loss: 0.004
[5,    20] loss: 0.003
[6,    20] loss: 0.003
[7,    20] loss: 0.003
[8,    20] loss: 0.003
[9,    20] loss: 0.003
[10,    20] loss: 0.003
[11,    20] loss: 0.003
[12,    20] loss: 0.003
[13,    20] loss: 0.003
[14,    20] loss: 0.003
[15,    20] loss: 0.003
[16,    20] loss: 0.003
[17,    20] loss: 0.003
[18,    20] loss: 0.003
[19,    20] loss: 0.003
[20,    20] loss: 0.003
[21,    20] loss: 0.003
[22,    20] loss: 0.003
[23,    20] loss: 0.003
[24,    20] loss: 0.003
[25,    20] loss: 0.003
[26,    20] loss: 0.003
[27,    20] loss: 0.003
[28,    20] loss: 0.003
[29,    20] loss: 0.003
[30,    20] loss: 0.003
[31,    20] loss: 0.003
[32,    20] loss: 0.003
[33,    20] loss: 0.003
[34,    20] loss: 0.003
[35,    20] loss: 0.003
[36,    20] loss: 0.003
[37,    20] loss: 0.003
[38,    20] loss: 0.003
[39,    20] loss: 0.003
[40,    20] loss: 0.003
[41,    20] loss: 0.003
[42,    20] loss: 0.003
[

In [None]:

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(downstream_train_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = net(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(downstream_val_loader):
        # Every data instance is an input + label pair
        inputs, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss