### Startup

This code is meant to be executed on Google Colab.
To use it locally change *COLAB_MODE* to False.

**Note**: remember to change *workdir* accordingly, the notebook must be runned inside the root project folder

In [None]:
COLAB_MODE = True

if COLAB_MODE:
  from google.colab import drive
  drive.mount('/content/drive')
  workdir = "/content/drive/<workspace>"
  %cd $workdir

else: workdir = "./"

We use TensorBoardX + Comet.com for logging. (Update the comet config with your data)

In [None]:
# Comet + Tensorboard
!pip install comet_ml --quiet
!pip install tensorboardX --quiet
import comet_ml
from tensorboardX import SummaryWriter

comet_config = {
   "api_key": "<key>", #Insert your Comet API key
   "workspace":"<workspace>", # Comet workspace
   "project_name":"<project>",
   "disabled": False # Disable Comet logging
}


# Dataset

Dataset configuration and builders. Each dataset must be stored in
`datasets/<type>/<dataset_name>.zip`. They will be extracted at runtime in `$destination_path/<dataset_name>` (working directly on Google Drive is not recommended)

Expected dataset structure


```
train/
---0_real/
------img1.png
------img2.png
---1_fake/
------img1.png
------img2.png
val/
---0_real/
------img1.png
------img2.png
---1_fake/
------img1.png
------img2.png
test/
---0_real/
------img1.png
------img2.png
---1_fake/
------img1.png
------img2.png
```





In [3]:
def get_dataset_cfg(name, args):
  if args.ds_cfg["type"] == "guarnera":
    return get_guarnera_cfg(name)
  elif args.ds_cfg["type"] == "cddb":
    return get_cddb_cfg(name)

## Guarnera

Available here: https://iplab.dmi.unict.it/mfs/FightingDeepfake/


In [4]:
import random
import os, zipfile, shutil

def get_guarnera_cfg(name):

  ds_infos = {
      "celeba": {
          "real":       True,
          "zip_path":   workdir + "/datasets/guarnera/CELEBA.zip",
          "n_total":    35976,
          "n_train":    2000,
          "n_val":       1000,
          "n_test":      3000,
          "shuffle":     False,
          "batch":       True,
          },
      "ffhq": {
          "real":       True,
          "zip_path":   workdir + "/datasets/guarnera/FFHQ.zip",
          "n_total":    4000,
          "n_train":    2000,
          "n_val":       600,
          "n_test":      1400,
          },
      "attgan": {
          "real":       False,
          "zip_path":   workdir + "/datasets/guarnera/AttGAN.zip",
          "n_total":    6005,
          "n_train":    2000,
          "n_val":       1000,
          "n_test":      3000,
          "batch_id":    0,
          },
      "cyclegan": {
          "real":       False,
          "zip_path":   workdir + "/datasets/guarnera/CYCLEGAN.zip",
          "n_total":    2190,
          "n_train":    1000,
          "n_val":       500,
          "n_test":      690,
          "batch_id":    5,
          },
      "faceapp":{
          "real":       False,
          "zip_path":   workdir + "/datasets/guarnera/FACEAPP.zip",
          "n_total":    471,
          "n_train":    250,
          "n_val":      50,
          "n_test":      171,
          "batch_id":    -1,
          },
      "gdwct": {
          "real":       False,
          "zip_path":   workdir + "/datasets/guarnera/GDWCT.zip",
          "n_total":    3367,
          "n_train":    1700,
          "n_val":      500,
          "n_test":      1167,
          "batch_id":    4,
          },
      "imle": {
          "real":       False,
          "zip_path":   workdir + "/datasets/guarnera/IMLE.zip",
          "n_total":    2006,
          "n_train":    1000,
          "n_val":      500,
          "n_test":      506,
          "batch_id":    5,
          },
      "progan": {
          "real":       False,
          "zip_path":   workdir + "/datasets/guarnera/ProGAN.zip",
          "n_total":    4000,
          "n_train":    2000,
          "n_val":      1000,
          "n_test":     1000,
          "batch_id":   -1,
          },
      "stargan": {
          "real":       False,
          "zip_path":   workdir + "/datasets/guarnera/STARGAN.zip",
          "n_total":    5648,
          "n_train":    2000,
          "n_val":      1000,
          "n_test":     2648,
          "batch_id":   1,
          },
      "stylegan": {
          "real":       False,
          "zip_path":   workdir + "/datasets/guarnera/STYLEGAN.zip",
          "n_total":    9999,
          "n_train":    2000,
          "n_val":      1000,
          "n_test":      7000,
          "batch_id":    2,
          },
      "stylegan2": {
          "real":       False,
          "zip_path":   workdir + "/datasets/guarnera/STYLEGAN2.zip",
          "n_total":    6000,
          "n_train":    2000,
          "n_val":      1000,
          "n_test":      3000,
          "batch_id":    3,
          },
      }

  return ds_infos.get(name)



def build_guarnera_dataset(real_ds, fake_ds, destination_path="/dataset", erase=False, continual_mode=False):

  def unzip_dataset(ds_info, real_fake , des_path, batch_id=-1, ds_name=''):

    tmp_dir = destination_path+"/tmp"
    shutil.rmtree(tmp_dir, ignore_errors=True)
    os.makedirs(tmp_dir, exist_ok=True)

    zip_path = ds_info["zip_path"]
    n_tr = ds_info["n_train"]
    n_v = ds_info["n_val"]
    n_te = ds_info["n_test"]

    os.makedirs(des_path+"/test/0_real", exist_ok=True)
    os.makedirs(des_path+"/test/1_fake", exist_ok=True)

    os.makedirs(des_path+"/train/0_real", exist_ok=True)
    os.makedirs(des_path+"/train/1_fake", exist_ok=True)

    os.makedirs(des_path+"/val/0_real", exist_ok=True)
    os.makedirs(des_path+"/val/1_fake", exist_ok=True)

    print(f"extracting in {des_path}")
    with zipfile.ZipFile(zip_path) as zip:
      zip_list = zip.infolist()

      start = 0
      limit = n_tr + n_v + n_te + 10
      limit = limit if limit < len(zip_list) else len(zip_list)

      if "CELEBA" in zip_path:
        if ds_info.get("shuffle") is True:
          print("Random sampling")
          random.shuffle(zip_list)
        elif batch_id != -1:
          print(f"Using batch {batch_id}")
          start = batch_id * 6000

      print(f"Extracting images {start} - {start+limit}")
      for i in range(start, start+limit):
        zip.extract(zip_list[i],path=tmp_dir)

      count = 0
      for (path, dirlist, filelist) in os.walk(tmp_dir):

        for filename in filelist:
          if count < n_tr:
            shutil.move(path+"/"+filename, f"{des_path}/train/{real_fake}/{count}_{filename}")
          elif count < n_tr+n_v:
            shutil.move(path+"/"+filename, f"{des_path}/val/{real_fake}/{count}_{filename}")
          elif count < n_tr+n_v+n_te:
            shutil.move(path+"/"+filename, f"{des_path}/test/{real_fake}/{count}_{filename}")

          count += 1

        if count != 0: print(f"{count-1} images extracted")


  if erase: shutil.rmtree(destination_path, ignore_errors=True)

  for ds in fake_ds:
    fake_cfg = get_dataset_cfg(ds, real=False)
    real_cfg = get_dataset_cfg(real_ds)
    batch_id = fake_cfg["batch_id"]
    if continual_mode:
      des_path = f"{destination_path}/{ds}"
    else:
      des_path = destination_path

    print(f"Extracting fake dataset {ds}")
    unzip_dataset(fake_cfg, "1_fake", des_path, ds_name=ds)

    print(f"Extracting real dataset {real_ds}")
    unzip_dataset(real_cfg, "0_real", des_path, batch_id=batch_id)


## CDDB + Diffusion

* CDDB is available here: https://coral79.github.io/CDDB_web/
* The diffusion dataset is custom made

In [5]:
from sklearn.base import MultiOutputMixin
import random
import os, zipfile, shutil

def get_cddb_cfg(name):

  ds_infos = {
      "biggan": {
          "zip_path":   workdir + "/datasets/custom/biggan.zip",
          },
      "crn": {
          "zip_path":   workdir + "/datasets/custom/crn.zip",
          },
      "cyclegan": {
          "zip_path":   workdir + "/datasets/custom/cyclegan.zip",
          },
      "faceforensics": {
          "zip_path":   workdir + "/datasets/custom/faceforensics.zip",
          },
      "gaugan": {
          "zip_path":   workdir + "/datasets/custom/gaugan.zip",
          },
      "glow": {
          "zip_path":   workdir + "/datasets/custom/glow.zip",
          },
      "imle": {
          "zip_path":   workdir + "/datasets/custom/imle.zip",
          },
      "san": {
          "zip_path":   workdir + "/datasets/custom/san.zip",
          },
      "stargan": {
          "zip_path":   workdir + "/datasets/custom/stargan.zip",
          },
      "stylegan": {
          "zip_path":   workdir + "/datasets/custom/stylegan.zip",
          },
      "whichfaceisreal": {
          "zip_path":   workdir + "/datasets/custom/whichfaceisreal.zip",
          },
      "wild": {
          "zip_path":   workdir + "/datasets/custom/wild.zip",
          },
      "diffusionshort": {
          "zip_path":   workdir + "/datasets/custom/diffusionshort.zip",
          },

      }

  return ds_infos.get(name)


# Erase: Empty the destination_path folder
# Continual_mode = extract on destination_path/dataset_name instead of destination_path/
# Limit = Limit the number of extracted images (test, validation, test)
def build_cddb_dataset(ds_list, destination_path="/dataset", erase=False, continual_mode=True, shuffle=False, limit=(100000,100000,100000)):

  def unzip_dataset(ds_info , des_path):

    zip_path = ds_info["zip_path"]
    os.makedirs(des_path, exist_ok=True)

    with zipfile.ZipFile(zip_path) as zip:
      zip_list = zip.infolist()

      if shuffle:
          print("Random sampling")
          random.shuffle(zip_list)

      test_real, train_real, val_real = 0,0,0
      test_fake, train_fake, val_fake = 0,0,0
      for zfile in zip_list:
        if not zfile.is_dir():
          if "train" in zfile.filename:
            if "real" in zfile.filename and train_real < limit[0]:
              zip.extract(zfile,path=des_path)
              train_real += 1
            elif "fake" in zfile.filename and train_fake < limit[0]:
              zip.extract(zfile,path=des_path)
              train_fake += 1

          elif "val" in zfile.filename:
            if "real" in zfile.filename and val_real < limit[1]:
              zip.extract(zfile,path=des_path)
              val_real += 1
            elif "fake" in zfile.filename and val_fake < limit[1]:
              zip.extract(zfile,path=des_path)
              val_fake += 1

          elif "test" in zfile.filename:
            if "real" in zfile.filename and test_real < limit[2]:
              zip.extract(zfile,path=des_path)
              test_real += 1
            elif "fake" in zfile.filename and test_fake < limit[2]:
              zip.extract(zfile,path=des_path)
              test_fake += 1

      print(f"TOT (for each class): {(train_real+val_real+test_real)}, train {train_real}, val {val_real}, test {test_real}\n")

  if erase: shutil.rmtree(destination_path, ignore_errors=True)

  for ds_cluster in ds_list:
    if continual_mode:
      des_path = f"{destination_path}/{ds_cluster}"
    else:
      des_path = destination_path

    cluster_list = ds_cluster.split(".")
    for ds in cluster_list:
      cfg = get_cddb_cfg(ds)
      print(f"Extracting dataset {ds} in {des_path}")
      unzip_dataset(cfg, des_path)



# Transfer Learning

In [6]:
%load_ext autoreload
%autoreload 2

import sys
from common_functions import *
from cored_functions import *
from torch.cuda.amp import autocast, GradScaler
from torchsummary import summary
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
from comet_ml.integration.pytorch import log_model


def tl_train(args, log=None):
    # Init
    lr = args.lr
    savepath = f"{args.checkpoints_dir}/{args.name_target}/"
    savepath = savepath.replace('//','/')
    if not os.path.isdir(savepath):
        os.makedirs(savepath)
    print(f'save path : {savepath}')

    # Logger
    writer = SummaryWriter(comet_config=comet_config)
    writer.add_hparams(hparam_dict=vars(args), metric_dict={})
    experiment = comet_ml.get_global_experiment()
    if experiment: experiment.set_name(args.name)


    # Load datasets and models
    dicLoader, _, _ = initialization(args)
    _, model = load_models(args.weight, nameNet=args.network, num_gpu=args.num_gpu, TrainMode=True)
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.1)
    lr_scheduler = CosineAnnealingLR(optimizer=optimizer,
                                             T_max=10,
                                             eta_min=1e-5,
                                             verbose=True)
    scaler = GradScaler()
    summary(model, (3, 128, 128), 64)



    best_acc, epochs = 0, args.epochs
    print('epochs={}'.format(epochs))
    is_best_acc = False
    step = 0
    cur_patience = 0

    # ------- START TRAINING ------- #
    for epoch in range(epochs):
        running_loss = []
        correct,total = 0,0
        model.train()

        for batch_idx, (inputs, targets) in enumerate(dicLoader['train_target']):
            step = (batch_idx+1) * (epoch+1)
            inputs, targets = inputs.cuda(), targets.cuda()

            # Forward
            with autocast(enabled=True):
                outputs = model(inputs)
                loss_main = criterion(outputs, targets)
                loss = loss_main

            # Learn
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Predictions
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == targets).sum().item()
            total += len(targets)

            # Log and print
            running_loss.append(loss_main.cpu().detach().numpy())
            writer.add_scalar('losses/loss', loss.item(), step, display_name="Loss")
            writer.add_scalar('acc/train_acc', correct/total, step, display_name="Train Accuracy")
            print("Train Epoch: {e:03d} Batch: {batch:05d}/{size:05d} | Loss: {loss:.4f}"
                            .format(e=epoch+1, batch=batch_idx, size=len(dicLoader['train_target']), loss=loss.item()))

        writer.add_scalar('losses/CE_loss', np.mean(running_loss), step, display_name="CE Loss")
        print("\nEpoch: {}/{} - CE_Loss: {:.4f} | ACC: {:.4f}".format(epoch+1, epochs, np.mean(running_loss), correct / total))
        lr_scheduler.step()


        # ----- Validation ------ #
        _, _, test_acc = Test(dicLoader['val_target'], model, criterion)
        writer.add_scalar('acc/val_acc', test_acc, step, display_name="Validation Accuracy")
        print(f"Test accuracy: {test_acc}")
        total_acc = test_acc

        is_best_acc = total_acc > best_acc
        if is_best_acc:
            cur_patience = 0
        else:
            cur_patience += 1

        if is_best_acc:
            print("Save model ...")
            best_acc = max(total_acc, best_acc)
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict()
            },
            checkpoint = savepath,
            filename = 'epoch_{}'.format(epoch+1 if (epoch+1)%10==0 else ''),
            ACC_BEST=is_best_acc )
            if experiment: log_model(experiment, model, model_name=args.name)

        if args.early_stop and (cur_patience == args.patience):
              print("Early stopping ...")
              writer.close()
              return

    writer.close()



# Knowledge Distillation

The code incorporates elements derived from the code originally published in the research paper, which can be found here: https://github.com/alsgkals2/CoReD

In [7]:
%load_ext autoreload
%autoreload 2

import sys
from common_functions import *
from cored_functions import *
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
from train_utils import loss_clampping, ReduceWeightOnPlateau

from torchsummary import summary
from tensorboardX import SummaryWriter
from comet_ml.integration.pytorch import log_model

def kd_train(args, log = None):

    # Init
    torch.cuda.empty_cache()
    device = 'cuda' if args.num_gpu else 'cpu'
    lr = args.lr
    KD_alpha = args.KD_alpha
    num_class = args.num_class
    num_store_per = args.num_store
    savepath = f"{args.checkpoints_dir}/{args.name_sources}_{args.name_target}/"
    savepath = savepath.replace('//','/')
    if not os.path.isdir(savepath):
        os.makedirs(savepath)
    print(f'save path: {savepath}')

    # Logger
    writer = SummaryWriter(comet_config=comet_config)
    writer.add_hparams(hparam_dict=vars(args), metric_dict={})
    experiment = comet_ml.get_global_experiment()
    if experiment: experiment.set_name(args.name)

    # Load datasets and models
    dicLoader, dicCoReD, dicSourceName = initialization(args)
    print("Dataset available in dicLoader: ", " / ".join([n for n in dicLoader]))
    print("Dataset available in dicCoReD: ", " / ".join([n for n in dicCoReD]))
    teacher_model, student_model = load_models(args.weight, args.network, num_gpu = args.num_gpu)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(student_model.parameters(), lr=lr, momentum=0.1)

    print("Teacher summary")
    summary(teacher_model, (3, args.resolution,args.resolution), 32)
    print("Student summary")
    summary(student_model, (3, args.resolution,args.resolution), 32)


    # Learning rate scheduler
    if args.lr_schedule == "cosine":
        print("Apply Cosine learning rate schedule")
        lr_scheduler = CosineAnnealingLR(optimizer=optimizer,
                                        T_max=10,
                                        eta_min=1e-5,
                                        verbose=True)
    elif args.lr_schedule == "onecycle":
        print("Apply OneCycle learning rate schedule")
        lr_scheduler = OneCycleLR(optimizer=optimizer,
                                    max_lr=lr,
                                    epochs=args.epochs,
                                    steps_per_epoch=len(dicLoader['train_target']),
                                    pct_start=0.05,
                                    total_steps=None,
                                    verbose=False)
    else:

        print(f"Input: {args.lr_schedule}, No learning rate schedule applied ... ")
        return
    watching_step = len(dicLoader['train_target']) // 10


    # Pre-evaluation
    print("Loading train target for correcting ...  ")
    _list_correct, _ = func_correct(teacher_model.to(device), dicCoReD['train_target_forCorrect'])
    _correct_loaders, already_correct_ratio = GetSplitLoaders_BinaryClasses(_list_correct, dicCoReD['train_target_dataset'], get_augs(args)[0], num_store_per)
    print("Ratio of already correctly predicted in training set: {:.3f}".format(already_correct_ratio))
    writer.add_scalar('start_acc', already_correct_ratio, 0, display_name="Target accuracy before traning")
    list_features = GetListTeacherFeatureFakeReal(teacher_model.module if ',' in args.num_gpu else teacher_model ,_correct_loaders, mode=args.network)
    list_features = np.array(list_features)
    print("List feature size: ", list_features.shape)

    # Initial validation
    _, _, test_acc = Test(dicLoader['val_target'], student_model, criterion, log = log, source_name = args.name_target)
    total_acc = test_acc
    print("[VAL Acc] Target: {:.2f}%".format( test_acc))
    cnt = 1
    for name in dicLoader:
        if 'val_dataset' in name or 'val_source' in name:
            if 'val_dataset' in name:
                source_name = dicSourceName[f'source{cnt}']
            else:
                source_name = dicSourceName['source']

            _, _, source_acc = Test(dicLoader[name], student_model, criterion, log = log, source_name = source_name)
            total_acc += source_acc
            print("[VAL Acc] Source {}-th: {:.2f}%".format(cnt, source_acc))
            cnt += 1

    print("[VAL Acc] Avg {:.2f}%\n Save initial model weight".format(total_acc / cnt))
    best_acc = total_acc
    save_checkpoint({
            'epoch': 0,
            'state_dict': student_model.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict()},
                checkpoint = savepath,
                filename = '',
                ACC_BEST=True
                )

    is_best_acc = False
    cur_patience = 0 # Early stop and saving
    l_weight = 1.0 # reduce the conservation when performance does not gain much
    print(f"Start training in {args.epochs} epochs")


    # ------- START TRAINING ------- #
    for epoch in range(args.epochs):
        correct,total = 0,0
        teacher_model.eval()
        student_model.train()
        disp = {}

        for batch_idx, (inputs, targets) in enumerate(dicLoader['train_target']):
            # Load data
            step = (batch_idx+1) * (epoch+1)
            inputs = inputs.to(device).to(torch.float32)
            targets = targets.to(device).to(torch.long)
            if torch.isnan(inputs).any() or torch.isnan(targets).any():
                raise ValueError("There is Nan values in input or target")

            # Forward
            teacher_outputs = teacher_model(inputs)
            penul_ft, outputs = student_model(inputs, True)

            # Losses
            loss_main = criterion(outputs, targets)
            loss_kd = loss_fn_kd(outputs, targets, teacher_outputs)
            loss_kd = loss_clampping(loss_kd, 0, 1800)

            #REP loss
            list_features_std = [list(), list()]
            rep_ft_partitions = correct_binary_simple(inputs=inputs, penul_ft=penul_ft, outputs=outputs, targets=targets) # rep_ft_partitions : 5 x 2
            for j in range(num_store_per):
                for i in range(num_class):
                    if(np.count_nonzero(list_features[i][j])==0 or len(rep_ft_partitions[j][i])==0):
                      continue
                    feat = torch.stack(rep_ft_partitions[j][i], dim=0).mean(dim=0)
                    assert feat.size(-1) == 2048 or feat.size(-1) == 512 or feat.size(-1) == 1280
                    rep_loss = (feat.to(torch.float32)  - torch.tensor(list_features[i][j]).to(device).to(torch.float32)).pow(2).mean()
                    list_features_std[i].append(rep_loss)
            sne_loss = 0.0
            for fs in list_features_std:
                for ss in fs:
                    if ss.requires_grad:
                        sne_loss += ss
            sne_loss = loss_clampping(sne_loss, 0, 1) # REP Loss is clampped in this project

            # Total loss
            loss = loss_main  + l_weight*(loss_kd + sne_loss)
            sne_item = sne_loss if type(sne_loss) == float else sne_loss.item()

            # Log and display
            writer.add_scalar('losses/loss', loss.item(), step, display_name="Total Loss")
            writer.add_scalar('losses/loss_main', loss_main.item(), step, display_name="Main Loss")
            writer.add_scalar('losses/loss_kd', loss_kd.item(), step, display_name="KD Loss")
            writer.add_scalar('losses/loss_sne', sne_item, step, display_name="SNE Loss")
            disp["CE"] = loss_main.item()
            disp["KD"] = loss_kd.item() if loss_kd > 0 else 0.0
            disp["REP"] = sne_item if sne_loss > 0 else 0.0
            call = ' | '.join(["{}: {:.4f}".format(k, v) for k, v in disp.items()])
            print("Train Epoch: {e:03d} Batch: {batch:05d}/{size:05d} | Loss: {loss:.4f} | {call}"
                            .format(e=epoch+1, batch=batch_idx+1, size=len(dicLoader['train_target']), loss=loss.item(), call=call))

            # Learn!
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if args.lr_schedule == "onecycle":
                lr_scheduler.step()

            # Predictions
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == targets).sum().item()
            total += len(targets)

        if args.lr_schedule == "cosine":
            lr_scheduler.step()


        # ----- Validation ------ #

        # Current task
        _, _, test_acc = Test(dicLoader['val_target'], student_model, criterion, log = None, source_name = args.name_target)
        total_acc = test_acc
        print("[VAL Acc] Target: {:.2f}%".format( test_acc))
        writer.add_scalar(f'acc/{args.name_target}_val_acc', test_acc, step, display_name=f"Target {args.name_target} Validation Accuracy")

        # Past tasks
        cnt = 1
        for name in dicLoader:
            if 'val_dataset' in name or 'val_source' in name:
                if 'val_dataset' in name:
                    source_name = dicSourceName[f'source{cnt}']
                else:
                    source_name = dicSourceName['source']

                _, _, source_acc = Test(dicLoader[name], student_model, criterion, log = None, source_name = source_name)
                total_acc += source_acc
                print("[VAL Acc] Source {}-th: {:.2f}%".format(cnt, source_acc))
                writer.add_scalar(f'acc/{source_name}_val_acc', source_acc, step, display_name=f"{source_name} Validation Accuracy")
                cnt += 1
        print("[VAL Acc] Avg {:.2f}%".format(total_acc / cnt))
        writer.add_scalar('acc/val_acc', total_acc/cnt, step, display_name="Average Validation Accuracy")

        # Early stop
        is_best_acc = total_acc > best_acc
        if is_best_acc:
                print("VAL Acc improve from {:.2f}% to {:.2f}%".format(best_acc/cnt, total_acc/cnt))
                cur_patience = 0
        else:
            cur_patience += 1
        if args.loss_schedule and (cur_patience > 0 and cur_patience % 4 == 0):
                l_weight = ReduceWeightOnPlateau(l_weight, args.decay_factor)

        # Save
        best_acc = max(total_acc,best_acc)
        if  is_best_acc:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': student_model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict()},
            checkpoint = savepath,
            filename = 'epoch_{}'.format( epoch+1 if (epoch+1)%10==0 else ''),
            ACC_BEST=is_best_acc
            )
            if experiment: log_model(experiment, student_model, model_name=args.name)
            print('Save best model' if is_best_acc else f'Save checkpoint model @ {epoch+1}')
        if args.early_stop and (cur_patience == args.patience):
            print("Early stopping ...")
            writer.close()
            return

    writer.close()


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Elastic Weight Consolidation


In [8]:
from copy import deepcopy

import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.data

class EWC(object):
    def __init__(self, model: nn.Module, datasets, optimizer , criterion, device, n_sample_batches):

        self.model = model
        self.datasets = datasets
        self.optimizer = optimizer
        self.device = device
        self.criterion = criterion
        self.n_sample_batches = n_sample_batches

        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        self._means = {}
        self._precision_matrices = self._diag_fisher()

        for n, p in deepcopy(self.params).items():
            self._means[n] = p.data.to(self.device)

    # Compute the diagonal Fisher information matrix
    def _diag_fisher(self):
        self.model.eval()
        precision_matrices = {}
        for n, p in deepcopy(self.params).items():
            p.data.zero_()
            precision_matrices[n] = p.data.to(self.device)

        for dataset in self.datasets:
          for batch_idx, (input, target) in enumerate(dataset):
            if batch_idx < self.n_sample_batches:
              input = input.to(self.device)
              target = target.to(self.device)
              self.optimizer.zero_grad()
              output = self.model(input)
              loss = self.criterion(output, target)
              loss.backward()

              for n, p in self.model.named_parameters():
                if p.grad is not None:
                  precision_matrices[n].data += p.grad.data.clone().pow(2) / len(dataset)

        precision_matrices = {n: p for n, p in precision_matrices.items()}
        self.model.train()
        return precision_matrices

    # Compute the EWC penalty
    def penalty(self, model: nn.Module):
        loss = 0.0
        for n, p in model.named_parameters():
          if n not in self._precision_matrices:
            continue
          _loss = self._precision_matrices[n] * ((p - self._means[n]) ** 2)
          loss += _loss.sum()
        return loss


In [10]:
%load_ext autoreload
%autoreload 2

import sys
from common_functions import *
from cored_functions import *
import torch.optim as optim
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR

from torchsummary import summary
from tensorboardX import SummaryWriter
from comet_ml.integration.pytorch import log_model

def ewc_train(args, log = None):

    # Init
    torch.cuda.empty_cache()
    device = 'cuda' if args.num_gpu else 'cpu'
    lr = args.lr
    num_class = args.num_class
    savepath = f"{args.checkpoints_dir}/{args.name_sources}_{args.name_target}/"
    savepath = savepath.replace('//','/')
    if not os.path.isdir(savepath):
        os.makedirs(savepath)

    # Logger
    writer = SummaryWriter(comet_config=comet_config)
    writer.add_hparams(hparam_dict=vars(args), metric_dict={})
    experiment = comet_ml.get_global_experiment()
    if experiment: experiment.set_name(args.name)

    # Load datasets and models
    dicLoader, _, dicSourceName = initialization(args)
    print("Dataset available in dicLoader: ", " / ".join([n for n in dicLoader]))
    _, model = load_models(args.weight, args.network, num_gpu = args.num_gpu)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.1)
    print("Model summary")
    summary(model, (3, args.resolution,args.resolution), 32)


    # Initial validation
    _, _, test_acc = Test(dicLoader['val_target'], model, criterion, log = log, source_name = args.name_target)
    writer.add_scalar('start_acc', test_acc, 0, display_name="Target accuracy before traning")
    print("Start Target Validation ACC: {:.2f}%".format(test_acc))


    # Get past tasks data loaders
    train_loader = dicLoader['train_target']
    old_tasks_loaders = []
    for name in dicLoader:
      if 'val_dataset' in name or 'val_source' in name:
        old_tasks_loaders.append(dicLoader[name])

    # Compute weight importance
    ewc = EWC(model, old_tasks_loaders, optimizer, criterion, device, args.sample_batch)


     # ------- START TRAINING ------- #
    is_best_acc = False
    best_acc = 0
    cur_patience = 0
    print(f"Start training in {args.epochs} epochs")
    for epoch in range(args.epochs):
        print(f"\n\n---------- Starting epoch {epoch} ----------")
        correct,total = 0,0
        model.train()

        correct, total = 0, 0
        tot_ewc_loss, tot_task_loss = 0.0, 0.0
        for inputs, targets in train_loader:
            # Load data
            inputs = inputs.to(device).to(torch.float32)
            targets = targets.to(device).to(torch.long)

             # Forward
            outputs = model(inputs)

             # Losses
            task_loss = criterion(outputs, targets)
            ewc_loss = args.importance * ewc.penalty(model)
            loss = task_loss + ewc_loss

            # Learn
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Predict
            tot_task_loss += task_loss.item()
            tot_ewc_loss += ewc_loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == targets).sum().item()
            total += len(targets)


        # Metrics
        epoch_task_loss = tot_task_loss / len(train_loader)
        epoch_ewc_loss = tot_ewc_loss / len(train_loader)
        tot_loss = epoch_task_loss + epoch_ewc_loss
        acc = (correct / total) * 100

        # Logging
        writer.add_scalar('losses/loss', tot_loss, epoch, display_name="Total Loss")
        writer.add_scalar('losses/task_loss', epoch_task_loss, epoch, display_name="Task Loss")
        writer.add_scalar('losses/ewc_loss', epoch_ewc_loss, epoch, display_name="EWC Loss")
        writer.add_scalar('acc/train_acc', acc, epoch, display_name="Train ACC")
        print(f"Train Epoch: {epoch:03d} |Acc {acc:0.5f} | Loss: {tot_loss:.5f} | Task Loss {epoch_task_loss:.5f} | EWC Loss: {epoch_ewc_loss:.5f}")


        # ---- START VALIDATION ---- #

        # Target
        _, _, test_acc = Test(dicLoader['val_target'], model, criterion, log = None, source_name = args.name_target)
        total_acc = test_acc
        writer.add_scalar(f'acc/{args.name_target}_val_acc', test_acc, epoch, display_name=f"Target {args.name_target} Validation Accuracy")
        print("[VAL Acc] Target: {:.2f}%".format(test_acc))

        # Old tasks
        cnt = 1
        for name in dicLoader:
          if 'val_dataset' in name or 'val_source' in name:
            if 'val_dataset' in name:
              source_name = dicSourceName[f'source{cnt}']
            else:
              source_name = dicSourceName['source']

            _, _, source_acc = Test(dicLoader[name], model, criterion, log = None, source_name = source_name)
            total_acc += source_acc
            print("[VAL Acc] Source {}-th: {:.2f}%".format(cnt, source_acc))
            writer.add_scalar(f'acc/{source_name}_val_acc', source_acc, epoch, display_name=f"{source_name} Validation Accuracy")
            cnt += 1
        print("[VAL Acc] Avg {:.2f}%".format(total_acc / cnt))
        writer.add_scalar('acc/val_acc', total_acc/cnt, epoch, display_name="Average Validation Accuracy")


        # Evaluate performances
        is_best_acc = total_acc > best_acc
        best_acc = max(total_acc,best_acc)
        if is_best_acc:
          print("VAL Acc improve from {:.2f}% to {:.2f}%".format(best_acc/cnt, total_acc/cnt))
          cur_patience = 0
        else:
          cur_patience += 1

        # Log validation
        if  is_best_acc:
          save_checkpoint({
              'epoch': epoch + 1,
              'state_dict': model.state_dict(),
              'best_acc': best_acc,
              'optimizer': optimizer.state_dict()},
          checkpoint = savepath,
          filename = 'epoch_{}'.format( epoch+1 if (epoch+1)%10==0 else ''),
          ACC_BEST=is_best_acc
          )
          if experiment: log_model(experiment, model, model_name=args.name) #Comet
          print('Save best model' if is_best_acc else f'Save checkpoint model @ {epoch+1}')

        # Early stop
        if args.early_stop and (cur_patience == args.patience):
                print("Early stopping ...")
                writer.close()
                return

    writer.close()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Evaluate


In [11]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import numpy as np
from common_functions import initialization, load_models, AverageMeter
from sklearn.metrics import classification_report, roc_auc_score, accuracy_score, average_precision_score
from tqdm import tqdm


def evaluate(args, global_writer=None):

    # Config
    setattr(args,"name_sources", "")
    setattr(args,"name_target", "")


    # Logger
    if global_writer is None:
      writer = SummaryWriter(comet_config=comet_config)
    else:
      writer = global_writer
    writer.add_hparams(hparam_dict=vars(args), metric_dict={})
    experiment = comet_ml.get_global_experiment()
    if experiment: experiment.set_name(args.name)


    if "real_ds" in args.ds_cfg:
      writer.add_hparams(hparam_dict=get_dataset_cfg(args.ds_cfg["real_ds"]), metric_dict={})


    # Load model
    _, model = load_models(args.weight, args.network, args.num_gpu, not args.test)
    criterion = nn.CrossEntropyLoss().cuda()

    # Load datasets
    tot_avg_acc, real_avg_acc, fake_avg_acc = 0.0, 0.0 ,0.0
    for ds_name in args.ds_cfg["fake_ds"]:
      writer.add_hparams(hparam_dict=get_dataset_cfg(ds_name, args), metric_dict={})
      data_folder = f"{args.dataroot}/{ds_name}/test"
      setattr(args,"data", data_folder)
      dicLoader,_, dicSourceName = initialization(args)


      for key, name in zip(dicLoader, dicSourceName):
        # Init
        global best_acc
        correct, total =0,0
        losses = AverageMeter()
        arc = AverageMeter()
        acc_real = AverageMeter()
        acc_fake = AverageMeter()
        sum_of_AUROC=[]
        target=[]
        output = []
        y_true=np.zeros((0,2),dtype=np.int8)
        y_pred=np.zeros((0,2),dtype=np.int8)

        with torch.no_grad():
          model.eval()
          model.cuda()

          for (inputs, targets) in tqdm(dicLoader[key], ncols=50):
              # Predict
              inputs, targets = inputs.to('cuda'), targets.to('cuda')
              outputs = model(inputs)
              loss = criterion(outputs, targets)
              _, predicted = torch.max(outputs, 1)
              correct = (predicted == targets).squeeze()
              total += len(targets)
              losses.update(loss.data.tolist(), inputs.size(0))
              _y_pred = outputs.cpu().detach()
              _y_gt = targets.cpu().detach().numpy()
              acc = [0, 0]
              class_total = [0, 0]
              for i in range(len(targets)):
                  label = targets[i]
                  acc[label] += 1 if correct[i].item() == True else 0
                  class_total[label] += 1

              losses.update(loss.data.tolist(), inputs.size(0))
              if (class_total[0] != 0):
                  acc_real.update(acc[0] / class_total[0])
              if (class_total[1] != 0):
                  acc_fake.update(acc[1] / class_total[1])

              target.append(_y_gt)
              output.append(_y_pred.numpy()[:,1])
              auroc=None
              try:
                  auroc = roc_auc_score(_y_gt, outputs[:,1].cpu().detach().numpy())
              except ValueError:
                  pass
              sum_of_AUROC.append(auroc)
              _y_true = np.array(torch.zeros(targets.shape[0],2), dtype=np.int8)
              _y_gt = _y_gt.astype(int)
              for _ in range(len(targets)):
                  _y_true[_][_y_gt[_]] = 1
              y_true = np.concatenate((y_true,_y_true))
              a = _y_pred.argmax(1)
              _y_pred = np.array(torch.zeros(_y_pred.shape).scatter(1, a.unsqueeze(1), 1),dtype=np.int8)
              y_pred = np.concatenate((y_pred,_y_pred))

          n_real_samples = np.count_nonzero(y_true, axis=0)[0]
          n_fake_samples = np.count_nonzero(y_true, axis=0)[1]
          acc = accuracy_score(y_true, y_pred)
          ap = average_precision_score(y_true, y_pred)

          result = classification_report(y_true, y_pred,
                                              labels=None,
                                              target_names=None,
                                              sample_weight=None,
                                              digits=4,
                                              output_dict=False,
                                              zero_division='warn')


          print(f"\nLoss:{losses.avg:.4f} | Acc:{acc:.4f} | Acc Real:{acc_real.avg:.4f} | Acc Fake:{acc_fake.avg:.4f} | Ap:{ap:.4f}")
          print(f'Num reals: {n_real_samples}, Num fakes: {n_fake_samples}')
          print("\n\n",result)

          tot_avg_acc += acc
          real_avg_acc += acc_real.avg
          fake_avg_acc += acc_fake.avg

          if experiment is not None:
            experiment.log_metrics(
                  {
                      "real_acc": acc_real.avg*100.,
                      "fake_acc": acc_fake.avg*100.,
                      "tot_acc": acc*100.,
                      "ap": ap*100.,
                  },
                  prefix=str(ds_name)
              )
            experiment.log_metrics(
                  {
                      "num_reals": n_real_samples,
                      "num_fakes": n_fake_samples,
                  },
                  prefix=str(ds_name)
              )

    total_ds = len(args.ds_cfg["fake_ds"])
    if experiment is not None:
              experiment.log_metrics(
                    {
                        "real_acc": (real_avg_acc/total_ds)*100.,
                        "fake_acc": (fake_avg_acc/total_ds)*100.,
                        "tot_acc": (tot_avg_acc/total_ds)*100.,
                    },
                    prefix="avg_acc"
                )
    print(f"Avg: | Acc:{tot_avg_acc/total_ds:.4f} | Acc Real:{real_avg_acc/total_ds:.4f} | Acc Fake:{fake_avg_acc/total_ds:.4f}")
    if global_writer is None: writer.close()


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Workspace

In this section it is possible to run trainings and evaluations

## Traning


1.   Configurate the training
2.   Build the dataset
3.   Train!



### Transfer Learning


In [13]:
from types import SimpleNamespace

# Set the datasets, it will be used to build the dataset folder
# Available datasets: biggan,crn,cyclegan,faceforensics,gaugan,glow,imle,san,stargan,stylegan,whichfaceisreal,wild,diffusionshort
ds_cfg = {
    "type":         "cddb",                 # cddb, guarnera
    #"real_ds":      "ffhq",                # used only for guarnera: ffhq, celeba
    "fake_ds":      ["gaugan"],             # for task 1 use only one dataset
}

train_cfg = {
    "name":         "ttl_gau",                # Used for tagging the experiment on logs
    "data":         "/dataset",             # Folder containing the EXTRACTED datasets
    "weight":       "",                     # OPTIONAL: load weights from file .pth, if empty preweights will be downloaded
    "network":      "ResNet",               # Backbone: ResNet, ResNet18, Xception, MobileNet2
    "name_sources": "gaugan",               # Training dataset
    "name_target":   "gaugan",              # Training dataset, must be the same as name_sources
    "checkpoints_dir": "./checkpoints/TL/demo", # Save folder
    "lr_schedule":  "cosine",               # cosine
    "test":         False,                  # False
    "use_gpu":      True,                   # True, False
    "num_gpu":      "0",                    # GPU id, used only if use_gpu=True
    "loss_schedule": True,                  # True, False
    "num_class":    2,                      # classification classes, 2 for binary classification
    "crop":         True,                   # Crop images instead of resize
    "flip":         False,                  # Random flip augmentation
    "resolution":   128,                    # Crop/resize resolution
    "lr":           0.005,                  # Learning rate
    "batch_size":   64,                     # Batch size
    "epochs":       200,                    # Traning epochs
    "early_stop":   True,                   # True, False
    "patience":     25,                     # Early stop patience
    "ds_cfg":        ds_cfg,
}

cfg = SimpleNamespace(**train_cfg)

In [14]:
build_cddb_dataset(cfg.ds_cfg["fake_ds"], erase=True, continual_mode=True)

Extracting dataset gaugan in /dataset/gaugan
TOT (for each class): 5000, train 3000, val 1000, test 1000



In [15]:
tl_train(cfg)

save path : ./checkpoints/TL/demo/gaugan/



------ Creating Loaders ------
GPU num is 0

===> Starting Task 1 loader from /dataset
Source: gaugan
Target: gaugan

---DATASET PATHS---
Train dir: /dataset/gaugan/train
Validation Source dir /dataset/gaugan/val
Validation Target dir /dataset/gaugan/val



 ------ Loading models ------


Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 186MB/s]


Adjusting learning rate of group 0 to 5.0000e-03.
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [64, 64, 64, 64]           9,408
       BatchNorm2d-2           [64, 64, 64, 64]             128
              ReLU-3           [64, 64, 64, 64]               0
         MaxPool2d-4           [64, 64, 32, 32]               0
            Conv2d-5           [64, 64, 32, 32]           4,096
       BatchNorm2d-6           [64, 64, 32, 32]             128
              ReLU-7           [64, 64, 32, 32]               0
            Conv2d-8           [64, 64, 32, 32]          36,864
       BatchNorm2d-9           [64, 64, 32, 32]             128
             ReLU-10           [64, 64, 32, 32]               0
           Conv2d-11          [64, 256, 32, 32]          16,384
      BatchNorm2d-12          [64, 256, 32, 32]             512
           Conv2d-13          [64, 256, 32, 32]      

100%|██████████| 32/32 [00:05<00:00,  5.48it/s]



Test | Loss:0.5603 | MainLoss:0.5603 | top:73.8500
Test accuracy: 73.85000000000001
Save model ...
Train Epoch: 002 Batch: 00000/00094 | Loss: 0.5465
Train Epoch: 002 Batch: 00001/00094 | Loss: 0.5540
Train Epoch: 002 Batch: 00002/00094 | Loss: 0.5722
Train Epoch: 002 Batch: 00003/00094 | Loss: 0.5716
Train Epoch: 002 Batch: 00004/00094 | Loss: 0.5952
Train Epoch: 002 Batch: 00005/00094 | Loss: 0.5883
Train Epoch: 002 Batch: 00006/00094 | Loss: 0.5417
Train Epoch: 002 Batch: 00007/00094 | Loss: 0.5220
Train Epoch: 002 Batch: 00008/00094 | Loss: 0.6050
Train Epoch: 002 Batch: 00009/00094 | Loss: 0.5457
Train Epoch: 002 Batch: 00010/00094 | Loss: 0.5994
Train Epoch: 002 Batch: 00011/00094 | Loss: 0.4979
Train Epoch: 002 Batch: 00012/00094 | Loss: 0.6124
Train Epoch: 002 Batch: 00013/00094 | Loss: 0.5545
Train Epoch: 002 Batch: 00014/00094 | Loss: 0.5787
Train Epoch: 002 Batch: 00015/00094 | Loss: 0.5665
Train Epoch: 002 Batch: 00016/00094 | Loss: 0.5628
Train Epoch: 002 Batch: 00017/000

100%|██████████| 32/32 [00:05<00:00,  5.63it/s]



Test | Loss:0.4583 | MainLoss:0.4583 | top:79.1000
Test accuracy: 79.10000000000001
Save model ...


### Knowledge distillation

In [None]:
from types import SimpleNamespace

# Set the datasets, it will be used to build the dataset folder
# Available datasets: biggan,crn,cyclegan,faceforensics,gaugan,glow,imle,san,stargan,stylegan,whichfaceisreal,wild",diffusionshort
ds_cfg = {
    "type":         "cddb",                 # cddb, guarnera
    #"real_ds":      "ffhq",                # used only for guarnera: ffhq, celeba
    "fake_ds":      ["gaugan", "biggan", "cyclegan"] # List of all datasets from first to current task
}

train_cfg = {
    "name":         "tkd_gau_big_cycle",    # Used for tagging the experiment on logs
    "data":         "/dataset",             # Folder containing the EXTRACTED datasets
    "weight":       "./checkpoints/KD/demo/gaugan_biggan/model_best_accuracy.pth", # load weights of task i-1 from file .pth
    "network":      "ResNet",               # Backbone: ResNet, ResNet18, Xception, MobileNet2
    "name_sources": "gaugan_biggan",        # Ordered list of previous task: dataset1_dataset2_dataseti-1
    "name_target":   "cyclegan",            # Task i dataset
    "checkpoints_dir": "./checkpoints/KD/demo", # Save folder (the task subfolder is automatically created)
    "lr_schedule":  "cosine",               # cosine, onecycle
    "test":         False,                  # False
    "use_gpu":      True,                   # True, False
    "num_gpu":      "0",                    # GPU id, used only if use_gpu=True
    "loss_schedule": True,                  # True, False
    "num_class":    2,                      # classification classes, 2 for binary classification
    "crop":         True,                   # Crop images instead of resize
    "flip":         False,                  # Random flip augmentation
    "resolution":   128,                    # Crop/resize resolution
    "KD_alpha":     0.5,                    # alpha factor for kd loss
    "num_store":    5,                      # Stores for representation loss
    "lr":           0.005,                  # Learning rate
    "batch_size":   64,                     # Batch size
    "epochs":       200,                    # Traning epochs
    "early_stop":   True,                   # True, False
    "patience":     25,                     # Early stop patience
    "ds_cfg":       ds_cfg,
}

cfg = SimpleNamespace(**train_cfg)

In [None]:
# limit: Limit the number of extracted images (test, validation, test)
build_cddb_dataset(cfg.ds_cfg["fake_ds"], erase=True, continual_mode=True, limit=(1000,1000,1000))

Extracting dataset gaugan in /dataset/gaugan
TOT (for each class): 30, train 10, val 10, test 10

Extracting dataset biggan in /dataset/biggan
TOT (for each class): 30, train 10, val 10, test 10

Extracting dataset cyclegan in /dataset/cyclegan
TOT (for each class): 30, train 10, val 10, test 10



In [None]:
kd_train(cfg)

### Elastic Weight Consolidation



In [19]:
from types import SimpleNamespace

# Set the datasets, it will be used to build the dataset folder
# Available datasets: biggan,crn,cyclegan,faceforensics,gaugan,glow,imle,san,stargan,stylegan,whichfaceisreal,wild",diffusionshort
ds_cfg = {
    "type":         "cddb",                 # cddb, guarnera
    #"real_ds":      "ffhq",                # used only for guarnera: ffhq, celeba
    "fake_ds":      ["gaugan", "biggan"] # List of all datasets from first to current task
}

train_cfg = {
    "name":         "tewc_gau_big",         # Used for tagging the experiment on logs
    "data":         "/dataset",             # Folder containing the EXTRACTED datasets
    "weight":       "./checkpoints/EWC/demo/gaugan/model_best_accuracy.pth", # load weights of task i-1 from file .pth
    "network":      "ResNet",               # Backbone: ResNet, ResNet18, Xception, MobileNet2
    "name_sources": "gaugan",               # Ordered list of previous task: dataset1_dataset2_dataseti-1
    "name_target":   "biggan",              # Task i dataset
    "checkpoints_dir": "./checkpoints/EWC/demo", # Save folder (the task subfolder is automatically created)
    "test":         False,                  # False
    "use_gpu":      True,                   # True, False
    "num_gpu":      "0",                    # GPU id, used only if use_gpu=True
    "num_class":    2,                      # classification classes, 2 for binary classification
    "crop":         True,                   # Crop images instead of resize
    "flip":         False,                  # Random flip augmentation
    "resolution":   128,                    # Crop/resize resolution
    "importance":  10,                      # Importance factor for ewc loss
    "sample_batch":  4,                     # Number of batches for each past task used for inportance matrix computation
    "lr":           0.005,                  # Learning rate
    "batch_size":   64,                     # Batch size
    "epochs":       2,                    # Traning epochs
    "early_stop":   True,                   # True, False
    "patience":     25,                     # Early stop patience
    "ds_cfg":       ds_cfg,
}

cfg = SimpleNamespace(**train_cfg)

In [17]:
build_cddb_dataset(cfg.ds_cfg["fake_ds"], erase=True, continual_mode=True)

Extracting dataset gaugan in /dataset/gaugan
TOT (for each class): 5000, train 3000, val 1000, test 1000

Extracting dataset biggan in /dataset/biggan
TOT (for each class): 2000, train 1200, val 400, test 400



In [20]:
ewc_train(cfg)




------ Creating Loaders ------
GPU num is 0

===> Starting Task 1 loader from /dataset
Source: gaugan
Target: biggan

---DATASET PATHS---
Train dir: /dataset/biggan/train
Validation Source dir /dataset/gaugan/val
Validation Target dir /dataset/biggan/val
Dataset available in dicLoader:  train_target / val_source / val_target / val_target_mix



 ------ Loading models ------
Loading ResNet from ./checkpoints/TL/demo/gaugan/model_best_accuracy.pth
Loaded
Model summary
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [32, 64, 64, 64]           9,408
       BatchNorm2d-2           [32, 64, 64, 64]             128
              ReLU-3           [32, 64, 64, 64]               0
         MaxPool2d-4           [32, 64, 32, 32]               0
            Conv2d-5           [32, 64, 32, 32]           4,096
       BatchNorm2d-6           [32, 64, 32, 32]             128
             

100%|██████████| 13/13 [00:02<00:00,  5.14it/s]


Test | Loss:0.8666 | MainLoss:0.8666 | top:51.6250
Start Target Validation ACC: 51.62%





Start training in 2 epochs


---------- Starting epoch 0 ----------
Train Epoch: 000 |Acc 57.29167 | Loss: 0.72133 | Task Loss 0.72116 | EWC Loss: 0.00017
===> Starting the dataset biggan


100%|██████████| 13/13 [00:02<00:00,  5.58it/s]



Test | Loss:0.7006 | MainLoss:0.7006 | top:55.7500
[VAL Acc] Target: 55.75%
===> Starting the dataset gaugan


100%|██████████| 32/32 [00:07<00:00,  4.12it/s]



Test | Loss:0.5350 | MainLoss:0.5350 | top:73.9000
[VAL Acc] Source 1-th: 73.90%
[VAL Acc] Avg 64.83%
VAL Acc improve from 64.83% to 64.83%
Save best model


---------- Starting epoch 1 ----------
Train Epoch: 001 |Acc 66.83333 | Loss: 0.61079 | Task Loss 0.61016 | EWC Loss: 0.00063
===> Starting the dataset biggan


100%|██████████| 13/13 [00:03<00:00,  3.31it/s]



Test | Loss:0.6523 | MainLoss:0.6523 | top:60.7500
[VAL Acc] Target: 60.75%
===> Starting the dataset gaugan


100%|██████████| 32/32 [00:06<00:00,  5.21it/s]



Test | Loss:0.5738 | MainLoss:0.5738 | top:69.7500
[VAL Acc] Source 1-th: 69.75%
[VAL Acc] Avg 65.25%
VAL Acc improve from 65.25% to 65.25%
Save best model


## Evaluation

In [None]:
from types import SimpleNamespace

# Set the datasets, it will be used to build the dataset folder
# Available datasets: biggan,crn,cyclegan,faceforensics,gaugan,glow,imle,san,stargan,stylegan,whichfaceisreal,wild,diffusionshort
ds_cfg = {
    "type":         "cddb",                 # cddb, guarnera
    #"real_ds":      "ffhq",                # used only for guarnera: ffhq, celeba
    "fake_ds":      ["gaugan", "biggan", "cyclegan"] # List of all datasets to test
}


evaluate_cfg = {

    "name":         "ekd_gau_big_cycle",    # Used for tagging the experiment on logs
    "dataroot":     "/dataset",             # Folder containing the EXTRACTED datasets
    "weight":       "./checkpoints/KD/demo/gaugan_biggan_cyclegan/model_best_accuracy.pth", # load weights of task i from file .pth
    "network":      "ResNet",               # Backbone: ResNet, ResNet18, Xception, MobileNet2
    "test":         True,                   # True
    "use_gpu":      True,                   # True, False
    "num_gpu":      "0",                    # GPU id, used only if use_gpu=True
    "crop":         True,                   # Crop images instead of resize
    "flip":         False,                  # Random flip augmentation
    "resolution":   128,                    # Crop/resize resolution
    "num_class":    2,                      # classification classes, 2 for binary classification
    "batch_size":   64,                     # Batch size
    "ds_cfg":        ds_cfg,
}

evaluate_cfg = SimpleNamespace(**evaluate_cfg)

In [None]:
build_cddb_dataset(evaluate_cfg.ds_cfg["fake_ds"], erase=True, continual_mode=True)

Extracting dataset gaugan in /dataset/gaugan
TOT (for each class): 5000, train 3000, val 1000, test 1000

Extracting dataset biggan in /dataset/biggan
TOT (for each class): 2000, train 1200, val 400, test 400

Extracting dataset cyclegan in /dataset/cyclegan
TOT (for each class): 1319, train 784, val 262, test 273



In [None]:
evaluate(evaluate_cfg)




 ------ Loading models ------
Loading ResNet18 from ./checkpoints/KD/demo/gaugan_biggan_cyclegan/model_best_accuracy.pth
Loaded



------ Creating Loaders ------
GPU num is 0

===> Starting Task 1 loader from /dataset/gaugan/test
Source: 
Target: 


100%|█████████████| 32/32 [00:09<00:00,  3.55it/s]



Loss:0.6381 | Acc:0.6350 | Acc Real:0.7534 | Acc Fake:0.5156 | Ap:0.5868
Num reals: 1000, Num fakes: 1000


               precision    recall  f1-score   support

           0     0.6089    0.7550    0.6741      1000
           1     0.6776    0.5150    0.5852      1000

   micro avg     0.6350    0.6350    0.6350      2000
   macro avg     0.6433    0.6350    0.6297      2000
weighted avg     0.6433    0.6350    0.6297      2000
 samples avg     0.6350    0.6350    0.6350      2000




------ Creating Loaders ------
GPU num is 0

===> Starting Task 1 loader from /dataset/biggan/test
Source: 
Target: 


100%|█████████████| 13/13 [00:02<00:00,  5.30it/s]



Loss:0.8496 | Acc:0.4875 | Acc Real:0.6017 | Acc Fake:0.3656 | Ap:0.4939
Num reals: 400, Num fakes: 400


               precision    recall  f1-score   support

           0     0.4899    0.6050    0.5414       400
           1     0.4837    0.3700    0.4193       400

   micro avg     0.4875    0.4875    0.4875       800
   macro avg     0.4868    0.4875    0.4803       800
weighted avg     0.4868    0.4875    0.4803       800
 samples avg     0.4875    0.4875    0.4875       800




------ Creating Loaders ------
GPU num is 0

===> Starting Task 1 loader from /dataset/cyclegan/test
Source: 
Target: 


100%|███████████████| 9/9 [00:01<00:00,  5.29it/s]



Loss:0.4771 | Acc:0.7766 | Acc Real:0.7253 | Acc Fake:0.8245 | Ap:0.7156
Num reals: 273, Num fakes: 273


               precision    recall  f1-score   support

           0     0.8082    0.7253    0.7645       273
           1     0.7508    0.8278    0.7875       273

   micro avg     0.7766    0.7766    0.7766       546
   macro avg     0.7795    0.7766    0.7760       546
weighted avg     0.7795    0.7766    0.7760       546
 samples avg     0.7766    0.7766    0.7766       546

Avg: | Acc:0.6330 | Acc Real:0.6935 | Acc Fake:0.5685
