In [1]:
import os
import sys

# sys.path.insert(0, os.getcwd())

import tensorflow as tf
# TensorFlow is only used for the tf.data input pipeline below; the model itself
# trains in PyTorch. Enabling memory growth makes TF allocate GPU memory on demand
# instead of reserving most of it at startup, leaving room for PyTorch.
# This must run before TF initializes the GPUs (TF raises RuntimeError otherwise).
physical_devices = tf.config.list_physical_devices(device_type="GPU")
print(physical_devices)
for physical_device in physical_devices:
    tf.config.experimental.set_memory_growth(physical_device, enable=True)

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, IterableDataset


2022-02-01 01:11:26.225285: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-01 01:11:26.250301: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-01 01:11:26.250414: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero


[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [2]:
def get_all_files(path, prefix="", suffix="", contains=""):
    if not os.path.isdir(path):
        raise ValueError(f"{path} is not a valid directory.")
    files = []
    for pre, dirs, basenames in os.walk(path):
        for name in basenames:
            if name.startswith(prefix) and name.endswith(suffix) and contains in name:
                files.append(os.path.join(pre, name))
    return files


In [3]:
# NOTE(review): DB_STAT_DIR is not referenced anywhere else in this notebook — confirm it is still needed.
DB_STAT_DIR = "/home/tai/1-workdir/1-deepfake-transformer/src/dataset_stuff/image_generator/db_stats"
PATCH_SIZE = 128
ROOT_DB_DIR = "/media/nas2/misl_image_db_70_class"
TRAIN_DS_PATH = f"{ROOT_DB_DIR}/train/{PATCH_SIZE}"
VAL_DS_PATH = f"{ROOT_DB_DIR}/val/{PATCH_SIZE}"

NUM_CLASSES = 70
BATCH_SIZE = 64


train_recs = get_all_files(TRAIN_DS_PATH, suffix=".tfrecord")
val_recs = get_all_files(VAL_DS_PATH, suffix=".tfrecord")
# get_all_files only checks that the directory exists. An empty shard list would
# otherwise give zero batches per epoch without raising any error.
assert train_recs, f"No .tfrecord files found under {TRAIN_DS_PATH}"
assert val_recs, f"No .tfrecord files found under {VAL_DS_PATH}"


AUTOTUNE = tf.data.AUTOTUNE

# Feature spec for each serialized example in the TFRecord shards:
#   "raw"   -> bytes of a serialized tensor (decoded with tf.io.parse_tensor)
#   "label" -> scalar int64 label (presumably the class index — confirm)
image_feature_description = dict(
    raw=tf.io.FixedLenFeature(shape=[], dtype=tf.string),
    label=tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
)

In [4]:
def _parse_image_function(example_proto):
    """Decode one serialized TFRecord example into an ``(image, label)`` pair.

    The ``raw`` feature holds a serialized float32 tensor, which is reshaped to
    ``[PATCH_SIZE, PATCH_SIZE, 3]`` (height, width, channels). The label is
    returned as an int64 scalar.
    """
    features = tf.io.parse_single_example(example_proto, image_feature_description)
    label = tf.cast(features["label"], tf.int64)
    decoded = tf.io.parse_tensor(features["raw"], tf.float32)
    return tf.reshape(decoded, [PATCH_SIZE, PATCH_SIZE, 3]), label


# Training: open `cycle_length` shards at once and interleave them, taking
# `block_length` consecutive examples from each in turn.
# NOTE(review): cycle_length=NUM_CLASSES presumably assumes one shard per class — confirm.
raw_train_set = tf.data.Dataset.from_tensor_slices(train_recs).interleave(
    lambda shard_path: tf.data.TFRecordDataset(shard_path).map(
        _parse_image_function, num_parallel_calls=AUTOTUNE
    ),
    num_parallel_calls=AUTOTUNE,
    cycle_length=NUM_CLASSES,
    block_length=2,
)
# Validation: read the shards sequentially. The parallel map only speeds up
# decoding; tf.data map preserves element order by default (deterministic=True).
raw_val_set = tf.data.TFRecordDataset(val_recs).map(_parse_image_function, num_parallel_calls=AUTOTUNE)

# NOTE(review): a shuffle buffer of only two batches mixes examples locally.
# Global mixing relies mostly on the interleave above.
train_tfds = (
    raw_train_set.shuffle(buffer_size=BATCH_SIZE * 2)
    .batch(batch_size=BATCH_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)
val_tfds = raw_val_set.batch(batch_size=BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)

2022-02-01 01:11:27.507739: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-02-01 01:11:27.509371: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-01 01:11:27.509612: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-02-01 01:11:27.509793: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zer

In [5]:
class MyIterableDataset(IterableDataset):
    """Wrap an iterable of TF ``(images, labels)`` batches as a PyTorch IterableDataset.

    Each element yielded by ``generator`` must expose ``.numpy()`` (e.g. eager
    TF tensors). Images arrive batch-first in BHWC layout and are permuted to
    BCHW; labels are converted to int64. Iteration restarts from
    ``iter(generator)`` each time, so a re-iterable source (such as a
    ``tf.data.Dataset``) gives a fresh pass per epoch.

    Note: ``__iter__`` does no per-worker sharding. With a DataLoader using
    ``num_workers > 0``, each worker would replay the full stream.
    """

    def __init__(self, generator):
        self.generator = generator

    def process_data(self, generator):
        """Yield ``(images_bchw, labels_int64)`` torch tensors for each TF batch."""
        for batch_images, batch_labels in generator:
            images_bchw = torch.from_numpy(batch_images.numpy()).permute(0, 3, 1, 2)  # BHWC -> BCHW
            labels_int64 = torch.from_numpy(batch_labels.numpy()).long()
            yield images_bchw, labels_int64

    def get_stream(self, generator):
        """Return a lazy stream of converted batches drawn from ``generator``."""
        return self.process_data(generator)

    def __iter__(self):
        return self.get_stream(self.generator)


In [6]:
# The tf.data pipelines already batch, so batch_size=None turns off DataLoader's
# automatic batching and passes each TF batch through unchanged.
# num_workers=0 keeps iteration in the main process; MyIterableDataset does no
# per-worker sharding, so extra workers would each replay the whole stream.
train_itds, val_itds = MyIterableDataset(train_tfds), MyIterableDataset(val_tfds)
train_dl, val_dl = (
    DataLoader(stream, batch_size=None, num_workers=0) for stream in (train_itds, val_itds)
)

In [7]:
from mislnet import MISLnetPLWrapper

In [8]:
# Model hyperparameters. Input size and class count come from the data-config
# constants so the model cannot silently disagree with the input pipeline.
# lr/momentum/decay_* are consumed by MISLnetPLWrapper; their exact semantics
# are defined there.
config = {
    "input_size": (PATCH_SIZE, PATCH_SIZE),
    "output_dim": 1024,
    "num_classes": NUM_CLASSES,
    "lr": 1e-3,
    "momentum": 0.95,
    "decay_rate": 0.75,
    "decay_step": 4,
}

model = MISLnetPLWrapper(config)
# Run name, used as the checkpoint filename prefix: "mislnet-<patch size>-<output_dim>".
model_name = f"mislnet-{PATCH_SIZE}-{config['output_dim']}"




In [9]:
# Checkpoint to start from; set to None to train from scratch.
prev_ckpt = "/home/tai/1-workdir/5-forensics-barlow-twins/src/lightning_logs/version_1/checkpoints/mislnet-128-1024=0-epoch=184-val_loss=0.9704.ckpt"

# resume=True: the Trainer also restores the full training state (epoch,
# optimizer, ...) from prev_ckpt and continues up to max_epochs.
# resume=False: only the model weights loaded below are reused.
resume = True
if prev_ckpt:
    # load_from_checkpoint is a classmethod that builds a NEW model instance;
    # call it on the class to make clear `model` is replaced, not updated in place.
    model = MISLnetPLWrapper.load_from_checkpoint(prev_ckpt, args=config)

# NOTE(review): reusing version=1 writes logs/checkpoints into the same directory
# as the run being resumed.
version = 1
monitor_metric = "val_loss"
logger = TensorBoardLogger(save_dir=os.getcwd(), version=version, name="src/lightning_logs")
lr_monitor = LearningRateMonitor(logging_interval="step")
model_ckpt = ModelCheckpoint(
    dirpath=f"src/lightning_logs/version_{version}/checkpoints",
    monitor=monitor_metric,
    # model_name is a literal prefix and must NOT be wrapped in braces. Lightning
    # treats every "{name}" in the template as a metric, so a braced model_name
    # became "mislnet-128-1024=0" in every filename.
    # Renders as e.g. "mislnet-128-1024-epoch=185-val_loss=0.9704.ckpt".
    filename=f"{model_name}-{{epoch:02d}}-{{{monitor_metric}:.4f}}",
    verbose=True,
    save_last=True,
    mode="min",
)

trainer = Trainer(
    gpus=1,
    max_epochs=200,
    resume_from_checkpoint=prev_ckpt if resume else None,
    # NOTE(review): progress_bar_refresh_rate, weights_summary and
    # resume_from_checkpoint are presumably the source of the
    # rank_zero_deprecation warnings in the output. Newer Lightning versions
    # replace them with TQDMProgressBar/ModelSummary callbacks and
    # trainer.fit(..., ckpt_path=...).
    progress_bar_refresh_rate=100,
    weights_summary="full",
    logger=logger,
    callbacks=[lr_monitor, model_ckpt],
    fast_dev_run=False,
)


  rank_zero_deprecation(
  rank_zero_deprecation(
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [10]:
# Train on train_dl and validate on val_dl each epoch. If the Trainer was built
# with resume_from_checkpoint, it first restores the saved training state.
trainer.fit(model, train_dl, val_dl)


  rank_zero_deprecation(
Restoring states from the checkpoint path at /home/tai/1-workdir/5-forensics-barlow-twins/src/lightning_logs/version_1/checkpoints/mislnet-128-1024=0-epoch=184-val_loss=0.9704.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(
Restored all states from the checkpoint file at /home/tai/1-workdir/5-forensics-barlow-twins/src/lightning_logs/version_1/checkpoints/mislnet-128-1024=0-epoch=184-val_loss=0.9704.ckpt

   | Name           | Type        | Params | In sizes         | Out sizes       
--------------------------------------------------------------------------------------
0  | model          | MISLnet     | 3.5 M  | [1, 3, 128, 128] | [1, 70]         
1  | model.model    | Sequential  | 3.5 M  | [1, 3, 128, 128] | [1, 70]         
2  | model.model.0  | Conv2d      | 228    | [1, 3, 128, 128] | [1, 3, 124, 124]
3  | model.model.1  | Conv2d      | 14.2 K | [1, 3, 124, 124] | [1, 96, 64, 64] 
4  | model.model.2  | BatchNorm2d | 192    | [1, 96, 64, 

Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Epoch 185, global step 5813987: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 186, global step 5845245: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 187, global step 5876503: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 188, global step 5907761: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 189, global step 5939019: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 190, global step 5970277: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 191, global step 6001535: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 192, global step 6032793: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 193, global step 6064051: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 194, global step 6095309: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 195, global step 6126567: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 196, global step 6157825: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 197, global step 6189083: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 198, global step 6220341: val_loss was not in top 1


Validating: 0it [00:00, ?it/s]

Epoch 199, global step 6251599: val_loss was not in top 1
Saving latest checkpoint...
