In [None]:
from init_notebook import *
from src.train.experiment import load_experiment_trainer
from functools import partial

In [None]:
def plot(ds, count=16*16):
    """
    Show a grid of the first ``count`` samples of dataset ``ds``.

    When the dataset yields tuples/lists, element 0 is taken as the image batch
    and every further tensor element whose last three dimensions match it is
    appended after it (e.g. targets shown next to inputs). Non-tensor elements
    are ignored. The grid has ``int(sqrt(count))`` images per row.
    """
    first_batch = next(iter(DataLoader(ds, batch_size=count)))

    if not isinstance(first_batch, (tuple, list)):
        images = first_batch
    else:
        images = first_batch[0]
        for extra in first_batch[1:]:
            if not isinstance(extra, torch.Tensor):
                continue
            if extra.shape[-3:] != images.shape[-3:]:
                continue
            images = torch.cat([images, extra], dim=0)

    grid = make_grid(images, nrow=int(math.sqrt(count)))
    display(VF.to_pil_image(grid))


## Load trainer and model

In [None]:
trainer = load_experiment_trainer(
    "../experiments/img2img/extrusion/extrusion-simple-adv.yml",
    device="auto",
)
# Load outside of an `assert`: assert statements are stripped under `python -O`,
# which would silently skip loading the weights.
checkpoint_loaded = trainer.load_checkpoint("snapshot")
if not checkpoint_loaded:
    raise RuntimeError('could not load the "snapshot" checkpoint')
model = trainer.model

### Create adversarial examples for a validation batch

In [None]:
def create_adversarial(
        model: nn.Module,
        image: torch.Tensor,
        max_adversarial_distance: float = 0.001,
        steps: int = 1000,
        lr: float = 0.0001,
        adversarial_distance_factor: float = 4.,
):
    """
    Optimize a perturbed copy of ``image`` whose model output differs from the
    model output of the unperturbed image while staying close to ``image``.

    AdamW minimizes
    ``L1(image, ad_image) - L1(ref_output, output) / adversarial_distance_factor``.
    Only the first channel of ``image`` is optimized; it is repeated to 3 channels
    before being passed to the model, and the model input and output are clamped to [0, 1].

    Side effects: puts ``model`` into eval mode and disables ``requires_grad`` on all
    of its parameters (not restored). Displays the reference output, the original
    image, the adversarial image and the model output for the adversarial image.

    :param model: the model to attack
    :param image: image batch; indexed as ``image[:, :1]``, so presumably
        (B, C, H, W) -- TODO confirm against the validation loader
    :param max_adversarial_distance: currently unused; kept for call compatibility
    :param steps: number of optimizer steps
    :param lr: AdamW learning rate
    :param adversarial_distance_factor: divisor of the output-distance term;
        larger values weigh closeness to the original image more
    :return: tuple of (adversarial image on CPU, clamped to [0, 1];
        DataFrame with per-step "loss", "od" (output distance), "ad" (adversarial distance))
    """
    model.eval()
    device = image.device  # fallback for a model without parameters
    for p in model.parameters():
        p.requires_grad = False
        device = p.device

    image = image[:, :1].to(device)

    with torch.no_grad():
        # clamp like the adversarial output below so both sides of the distance live in [0, 1]
        ref_image = model(image.repeat(1, 3, 1, 1)).clamp(0, 1)

    display(VF.to_pil_image(make_grid(ref_image)))

    ad_image = nn.Parameter(image.clone())
    optimizer = torch.optim.AdamW([ad_image], lr=lr)

    history = []
    with tqdm(total=steps) as progress:
        for _ in range(steps):
            output = model(ad_image.repeat(1, 3, 1, 1).clamp(0, 1)).clamp(0, 1)

            output_distance = F.l1_loss(ref_image, output)
            adversarial_distance = F.l1_loss(image, ad_image)

            loss = adversarial_distance - output_distance / adversarial_distance_factor

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            entry = {"loss": float(loss), "od": float(output_distance), "ad": float(adversarial_distance)}
            history.append(entry)
            progress.desc = " ".join(f"{k}={v}" for k, v in entry.items())
            progress.update()

    with torch.no_grad():
        # the model only ever saw the clamped image, so return exactly that;
        # out-of-range values would also be mangled by the 8-bit display conversion
        final_image = ad_image.clamp(0, 1)
        # re-evaluate so the displayed output belongs to the returned image
        # (the loop's last `output` predates the final optimizer step)
        final_output = model(final_image.repeat(1, 3, 1, 1)).clamp(0, 1)

    display(VF.to_pil_image(make_grid(image)))
    display(VF.to_pil_image(make_grid(final_image)))
    display(VF.to_pil_image(make_grid(final_output)))
    return final_image.cpu(), pd.DataFrame(history)
    
# attack the first 8 samples of element 1 of the first validation batch
ad_image, hist = create_adversarial(
    trainer.model,
    next(iter(trainer.validation_loader))[1][:8],
    steps=10_000,
)
# fake a time axis (one "second" per history row) so that `resample("50s")`
# averages the history over windows of 50 rows
df = hist.copy()
df.index = pd.to_datetime(df.index, unit="s")
df.resample("50s").mean().plot()

## Optimize an adversarial version of a whole image (random crops, tile by tile)

In [None]:
def create_adversarial(
        model: nn.Module,
        image: torch.Tensor,
        adversarial_distance_factor: float = 2.,
        steps: int = 1000,
        lr: float = 0.003,
        batch_size: int = 8,
        shape: Tuple[int, int] = (64, 64),
        tile_size: int = 512,
        steps_per_tile: int = 500,
):
    """
    Optimize an adversarial version of a whole (large) image from random crops.

    The image is walked tile by tile (``tile_size`` x ``tile_size``, raster order,
    starting at the top-left tile and wrapping around), spending ``steps_per_tile``
    steps on each tile. Every step samples ``batch_size`` random crops of ``shape``
    inside the current tile (shifted back into the image at the border) and makes
    two AdamW steps on the shared adversarial image:
    one increasing the L1 distance between the model outputs for the adversarial and
    original crops (scaled by ``1 / adversarial_distance_factor``), one decreasing the
    L1 distance between the adversarial and original crops.

    Side effects: puts ``model`` into eval mode and disables ``requires_grad`` on its
    parameters (not restored); prints the image size, displays the last crops and
    frees the CUDA cache. Ctrl-C (KeyboardInterrupt) stops early and still returns
    the current result.

    :param model: model to attack
    :param image: (C, H, W) image; only channel 0 is optimized and it is repeated
        to C channels for the model
    :param adversarial_distance_factor: float, or callable mapping the run progress
        ``t`` (0 at the first step, 1 at the last) to a float
    :param steps: number of steps (one batch of crops each)
    :param lr: AdamW learning rate
    :param batch_size: crops per step
    :param shape: crop (height, width)
    :param tile_size: edge length of the tiles the crops are sampled from
    :param steps_per_tile: steps spent on one tile before moving to the next
    :return: tuple of (adversarial image (C, H, W) on CPU, clamped to [0, 1];
        DataFrame with per-step "out-d" and "adv-d")
    :raises ValueError: if the image is smaller than ``shape`` or the tiling
        parameters are not positive
    """
    model.eval()
    device = image.device  # fallback for a model without parameters
    for p in model.parameters():
        p.requires_grad = False
        device = p.device

    C, H, W = image.shape
    print(f"image: {W}x{H}")
    crop_h, crop_w = shape[-2], shape[-1]
    if H < crop_h or W < crop_w:
        raise ValueError(f"image {W}x{H} is smaller than the crop shape {crop_w}x{crop_h}")
    if tile_size < 1 or steps_per_tile < 1:
        raise ValueError("tile_size and steps_per_tile must be positive")

    image = image[:1]
    ad_image = nn.Parameter(image.clone().to(device))
    history = []

    optimizer = torch.optim.AdamW([ad_image], lr=lr)

    def tile_origins():
        # endless raster-order walk over the tile grid, starting at the top-left tile
        while True:
            for tile_y in range(0, H, tile_size):
                for tile_x in range(0, W, tile_size):
                    yield tile_x, tile_y

    def random_start(tile_start, extent, crop):
        # random crop start inside the tile (both tile edges reachable); border tiles
        # can be narrower than the crop, so the start is shifted back into the image
        tile_extent = min(tile_size, extent - tile_start)
        start = tile_start + random.randrange(max(0, tile_extent - crop) + 1)
        return min(start, extent - crop)

    def yield_crops():
        origins = tile_origins()
        px, py = next(origins)
        with tqdm(total=steps * batch_size) as progress:
            for step in range(steps):
                if step and step % steps_per_tile == 0:
                    px, py = next(origins)
                t = step / max(1, steps - 1)
                image_crops, ad_image_crops = [], []
                for _ in range(batch_size):
                    x = random_start(px, W, crop_w)
                    y = random_start(py, H, crop_h)
                    image_crops.append(image[:, y: y + crop_h, x: x + crop_w].to(device))
                    ad_image_crops.append(ad_image[:, y: y + crop_h, x: x + crop_w])
                yield t, (
                    torch.concat([c.unsqueeze(0) for c in image_crops]),
                    torch.concat([c.unsqueeze(0) for c in ad_image_crops]),
                )

                if history:
                    progress.desc = " ".join(f"{k}={v}" for k, v in history[-1].items())
                progress.update(batch_size)

    # pre-bind so the display below can tell whether any batch was processed
    image_crops = ad_image_crops = output_crops = ref_crops = None
    try:
        for t, (image_crops, ad_image_crops) in yield_crops():
            with torch.no_grad():
                ref_crops = model(image_crops.repeat(1, C, 1, 1)).clamp(0, 1)
            output_crops = model(ad_image_crops.repeat(1, C, 1, 1).clamp(0, 1)).clamp(0, 1)

            output_distance = F.l1_loss(output_crops, ref_crops)

            if callable(adversarial_distance_factor):
                adf = adversarial_distance_factor(t)
            else:
                adf = adversarial_distance_factor

            # step 1: push the model output away from the reference output
            (-output_distance / adf).backward()
            optimizer.step()
            optimizer.zero_grad()

            # step 2: pull the adversarial crops back towards the original crops.
            # NOTE(review): `ad_image_crops` is a concatenated copy taken before
            # step 1, so this gradient is evaluated at the pre-step values
            adversarial_distance = F.l1_loss(ad_image_crops, image_crops)
            adversarial_distance.backward()
            optimizer.step()
            optimizer.zero_grad()

            history.append({"out-d": float(output_distance), "adv-d": float(adversarial_distance)})

            with torch.no_grad():
                ad_image[:] = ad_image.clamp(0, 1)

    except KeyboardInterrupt:
        pass

    with torch.no_grad():
        ad_image[:] = ad_image.clamp(0, 1)

    ad_image = ad_image.detach().cpu().repeat(C, 1, 1)
    if output_crops is not None:
        display(VF.to_pil_image(make_grid(image_crops)))
        display(VF.to_pil_image(make_grid(ad_image_crops)))
        display(VF.to_pil_image(make_grid(output_crops)))
        display(VF.to_pil_image(make_grid(ref_crops)))
        display(VF.to_pil_image(signed_to_image(make_grid(ref_crops - output_crops))))
    torch.cuda.empty_cache()
    return ad_image, pd.DataFrame(history)

def mix(a, b, t):
    """Linear interpolation: returns ``a`` at ``t == 0`` and ``b`` at ``t == 1``."""
    return t * b + a * (1 - t)
    
torch.cuda.empty_cache()  # release cached GPU memory from earlier cells
ad_image, hist = create_adversarial(
    trainer.model,
    VF.to_tensor(PIL.Image.open("../datasets/extrusion/train/source/005.png")),
    steps=4000,
    # ramp the factor from 1 to 2 over the first quarter of the run, then hold it at 2
    adversarial_distance_factor=lambda t: mix(1, 2, min(1, 4 * t)),
)
# one fake "second" per history row so `resample("50s")` averages windows of 50 rows
df = hist.copy()
df.index = pd.to_datetime(df.index, unit="s")
display(df.resample("50s").mean().plot())
display(VF.to_pil_image(ad_image))
display(VF.to_pil_image(ad_image))

## Whole image in strides (patch grid)

In [None]:
def create_adversarial(
        model: nn.Module,
        image: torch.Tensor,
        adversarial_distance_factor: float = 2.,
        steps: int = 1000,
        lr: float = 0.003,
        batch_size: int = 32,
        shape: Tuple[int, int] = (64, 64),
):
    """
    Optimize an adversarial version of a whole image, sweeping it patch by patch.

    Every step walks the image with ``iter_image_patches`` (stride == ``shape``,
    ``random_offset`` of a third of the patch size; exact semantics defined by that
    helper), in batches of ``batch_size`` patches. For every batch, two AdamW steps
    are made on the shared adversarial image: one increasing the L1 distance between
    the model outputs for adversarial and original patches (scaled by
    ``1 / adversarial_distance_factor``), one decreasing the L1 distance between
    adversarial and original patches.

    Side effects: puts ``model`` into eval mode and disables ``requires_grad`` on its
    parameters (not restored); prints the image size, displays the last patches and
    frees the CUDA cache. Ctrl-C (KeyboardInterrupt) stops early and still returns
    the current result.

    :param model: model to attack
    :param image: (C, H, W) image; only channel 0 is optimized and it is repeated
        to C channels for the model
    :param adversarial_distance_factor: float, or callable mapping the run progress
        ``t`` (0 during the first step, 1 during the last) to a float
    :param steps: number of sweeps over the image
    :param lr: AdamW learning rate
    :param batch_size: patches per optimizer batch
    :param shape: patch (height, width)
    :return: tuple of (adversarial image (C, H, W) on CPU, clamped to [0, 1];
        DataFrame with "out-d" and "adv-d" per batch)
    """
    model.eval()
    device = image.device  # fallback for a model without parameters
    for p in model.parameters():
        p.requires_grad = False
        device = p.device

    C, H, W = image.shape
    print(f"image: {W}x{H}")
    image = image[:1]
    ad_image = nn.Parameter(image.clone().to(device))
    history = []

    optimizer = torch.optim.AdamW([ad_image], lr=lr)

    def yield_crops():
        # progress estimate: the patch count of one non-overlapping tiling per step
        count = math.prod(image.shape[-2:]) / math.prod(shape[-2:])
        with tqdm(total=steps * count) as progress:
            for step in range(steps):
                t = step / max(1, steps - 1)
                for image_crops, positions in iter_image_patches(
                        image, shape, stride=shape, batch_size=batch_size, with_pos=True,
                        random_offset=[s // 3 for s in shape],
                ):
                    # positions are unpacked as (y, x) -- TODO confirm against iter_image_patches
                    ad_image_crops = torch.concat([
                        ad_image[:, y: y + shape[-2], x: x + shape[-1]].unsqueeze(0)
                        for y, x in positions
                    ])

                    yield t, image_crops.to(device), ad_image_crops
                    progress.update(image_crops.shape[0])

                    if history:
                        progress.desc = " ".join(f"{k}={v}" for k, v in history[-1].items())

    # pre-bind so the display below can tell whether any batch was processed
    image_crops = ad_image_crops = output_crops = ref_crops = None
    try:
        for t, image_crops, ad_image_crops in yield_crops():
            with torch.no_grad():
                ref_crops = model(image_crops.repeat(1, C, 1, 1)).clamp(0, 1)
            output_crops = model(ad_image_crops.repeat(1, C, 1, 1).clamp(0, 1)).clamp(0, 1)

            output_distance = F.l1_loss(output_crops, ref_crops)

            if callable(adversarial_distance_factor):
                adf = adversarial_distance_factor(t)
            else:
                adf = adversarial_distance_factor

            # step 1: push the model output away from the reference output
            (-output_distance / adf).backward()
            optimizer.step()
            optimizer.zero_grad()

            # step 2: pull the adversarial patches back towards the original patches.
            # NOTE(review): `ad_image_crops` is a concatenated copy taken before
            # step 1, so this gradient is evaluated at the pre-step values
            adversarial_distance = F.l1_loss(ad_image_crops, image_crops)
            adversarial_distance.backward()
            optimizer.step()
            optimizer.zero_grad()

            history.append({"out-d": float(output_distance), "adv-d": float(adversarial_distance)})

            with torch.no_grad():
                ad_image[:] = ad_image.clamp(0, 1)

    except KeyboardInterrupt:
        pass

    with torch.no_grad():
        ad_image[:] = ad_image.clamp(0, 1)

    ad_image = ad_image.detach().cpu().repeat(C, 1, 1)
    if output_crops is not None:
        display(VF.to_pil_image(make_grid(image_crops)))
        display(VF.to_pil_image(make_grid(ad_image_crops)))
        display(VF.to_pil_image(make_grid(output_crops)))
        display(VF.to_pil_image(make_grid(ref_crops)))
        display(VF.to_pil_image(signed_to_image(make_grid(ref_crops - output_crops))))
    torch.cuda.empty_cache()
    return ad_image, pd.DataFrame(history)

def mix(a, b, t):
    """Blend ``a`` towards ``b`` by fraction ``t`` (``t`` is not clamped)."""
    weight_a = 1 - t
    return weight_a * a + t * b
    
torch.cuda.empty_cache()  # release cached GPU memory from earlier cells
ad_image, hist = create_adversarial(
    trainer.model,
    VF.to_tensor(PIL.Image.open("../datasets/extrusion/train/source/005.png")),
    steps=2000,
)
# one fake "second" per history row so `resample("50s")` averages windows of 50 rows
df = hist.copy()
df.index = pd.to_datetime(df.index, unit="s")
display(df.resample("50s").mean().plot())
display(VF.to_pil_image(ad_image))

## Random image crops with repeated optimizer steps

In [None]:
def create_adversarial(
        model: nn.Module,
        image: torch.Tensor,
        adversarial_distance_factor: float = 2.,
        steps: int = 1000,
        lr: float = 0.0003,
        batch_size: int = 8,
        repeats: int = 10,
        shape: Tuple[int, int] = (64, 64),
):
    """
    Optimize an adversarial version of a whole image from random crops, making
    ``repeats`` optimizer steps on every sampled batch of crop positions.

    Each step samples ``batch_size`` random crops of ``shape`` anywhere in the image.
    For every repeat the adversarial crops are cut from the *current* adversarial
    image and AdamW minimizes
    ``L1(ad_crops, crops) - L1(model(ad_crops), model(crops)) / adversarial_distance_factor``.
    The adversarial image is clamped to [0, 1] after every batch and before returning.

    Side effects: puts ``model`` into eval mode and disables ``requires_grad`` on its
    parameters (not restored); prints the image size, displays the last crops and
    frees the CUDA cache. Ctrl-C (KeyboardInterrupt) stops early and still returns
    the current result.

    :param model: model to attack
    :param image: (C, H, W) image; only channel 0 is optimized and it is repeated
        to C channels for the model
    :param adversarial_distance_factor: divisor of the output-distance term
    :param steps: number of crop batches
    :param lr: AdamW learning rate
    :param batch_size: crops per batch
    :param repeats: optimizer steps per batch
    :param shape: crop (height, width)
    :return: tuple of (adversarial image (C, H, W) on CPU, clamped to [0, 1];
        DataFrame with "out-d" and "adv-d" per optimizer step)
    """
    model.eval()
    device = image.device  # fallback for a model without parameters
    for p in model.parameters():
        p.requires_grad = False
        device = p.device

    C, H, W = image.shape
    print(f"image: {W}x{H}")
    crop_h, crop_w = shape[-2], shape[-1]
    image = image[:1]
    ad_image = nn.Parameter(image.clone().to(device))
    history = []

    optimizer = torch.optim.AdamW([ad_image], lr=lr)

    def cut(source, positions):
        # batch of the crop_h x crop_w windows of `source` at the (x, y) `positions`
        return torch.concat([
            source[:, y: y + crop_h, x: x + crop_w].unsqueeze(0)
            for x, y in positions
        ])

    def yield_crops():
        with tqdm(total=steps * batch_size * repeats) as progress:
            for _ in range(steps):
                # `+ 1` so crops touching the right/bottom border can be sampled too
                positions = [
                    (random.randrange(W - crop_w + 1), random.randrange(H - crop_h + 1))
                    for _ in range(batch_size)
                ]
                yield positions, cut(image, positions).to(device)

                if history:
                    progress.desc = " ".join(f"{k}={v}" for k, v in history[-1].items())
                progress.update(batch_size * repeats)

    # pre-bind so the display below can tell whether any batch was processed
    image_crops = ad_image_crops = output_crops = ref_crops = None
    try:
        for positions, image_crops in yield_crops():
            with torch.no_grad():
                ref_crops = model(image_crops.repeat(1, C, 1, 1)).clamp(0, 1)

            for _ in range(repeats):
                # re-cut on every repeat: `cut` returns a copy, so crops cut once per
                # batch would keep showing the pre-update values to all repeats
                ad_image_crops = cut(ad_image, positions)
                output_crops = model(ad_image_crops.repeat(1, C, 1, 1).clamp(0, 1)).clamp(0, 1)

                output_distance = F.l1_loss(output_crops, ref_crops)
                adversarial_distance = F.l1_loss(ad_image_crops, image_crops)

                loss = adversarial_distance - output_distance / adversarial_distance_factor
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                history.append({"out-d": float(output_distance), "adv-d": float(adversarial_distance)})

            with torch.no_grad():
                ad_image[:] = ad_image.clamp(0, 1)

    except KeyboardInterrupt:
        pass

    # an interrupt may land between an optimizer step and the per-batch clamp
    with torch.no_grad():
        ad_image[:] = ad_image.clamp(0, 1)

    ad_image = ad_image.detach().cpu().repeat(C, 1, 1)
    if output_crops is not None:
        display(VF.to_pil_image(make_grid(image_crops)))
        display(VF.to_pil_image(make_grid(ad_image_crops)))
        display(VF.to_pil_image(make_grid(output_crops)))
        display(VF.to_pil_image(make_grid(ref_crops)))
        display(VF.to_pil_image(signed_to_image(make_grid(ref_crops - output_crops))))
    torch.cuda.empty_cache()
    return ad_image, pd.DataFrame(history)

torch.cuda.empty_cache()  # release cached GPU memory from earlier cells
ad_image, hist = create_adversarial(
    trainer.model,
    VF.to_tensor(PIL.Image.open("../datasets/extrusion/train/source/004.png")),
    steps=20_000,
)
# one fake "second" per history row so `resample("100s")` averages windows of 100 rows
df = hist.copy()
df.index = pd.to_datetime(df.index, unit="s")
display(df.resample("100s").mean().plot())
display(VF.to_pil_image(ad_image))