In [2]:
import yaml 
from torchvision import transforms, datasets
import random
import os
import numpy as np
import networkx as nx


### Loading the ImageNet graph

In [None]:
G = nx.DiGraph()

In [None]:
G.add_node('A')
G.add_node('B')
G.add_node('C')
G.add_node('D')

# Add edges
G.add_edge('A', 'B')
G.add_edge('A', 'C')
G.add_edge('B', 'D')

In [None]:

# Drop node 'D' from the toy graph and redraw it.
# The removal is guarded so that re-running this cell does not raise once 'D' is gone.
#nx.draw_networkx(G, nx.spring_layout(G))
if G.has_node('D'):
    G.remove_node('D')
nx.draw_networkx(G, nx.spring_layout(G))

In [None]:
# Build the lookup table from each line's first token to the rest of the line.
# Only the first whitespace run is split on, so the value part may contain spaces.
with open("data/imagenette2/synset_human.txt", "r") as f:
    synset_human_complete = dict(
        entry.split(maxsplit=1) for entry in f.read().splitlines()
    )

In [None]:
from imagenet_classnames import name_map, folder_label_map

synset_human = folder_label_map
index_human = name_map

# Walking both maps in sorted-key order must pair up identical labels, and the
# sorted keys of index_human must be exactly 0, 1, 2, ...
sorted_synsets = sorted(synset_human.keys())
sorted_indices = sorted(index_human.keys())
for i, (key_syn, key_index) in enumerate(zip(sorted_synsets, sorted_indices)):
    assert synset_human[key_syn] == index_human[key_index]
    assert i == key_index

# Position of each synset (in the map's own iteration order) -> synset id.
index_synset = dict(enumerate(synset_human.keys()))

# Cross-check: every position must resolve to the same label through both maps.
for index, synset in index_synset.items():
    assert index_human[index] == synset_human[synset]


In [None]:
# Read the raw relation list; each non-empty line holds two synset ids.
with open('data/wordnet.is_a.txt', 'r') as f:
    data = f.read()
#convert from synset to human and save to file
with open('data/wordnet.is_a_human.txt', 'w') as f:
    for line in data.splitlines():
        # Skip blank lines: unpacking an empty split would raise ValueError.
        if not line.strip():
            continue
        synset1, synset2 = line.split()
        # Written as "<second synset> is a <first synset>" using the human-readable names.
        f.write(f'{synset_human_complete[synset2]} is a {synset_human_complete[synset1]}\n')

In [None]:
# Two parallel directed graphs: one keyed by human-readable names, one by synset ids.
G_hum = nx.DiGraph()
G_syn = nx.DiGraph()

# Load the whole relation file into memory.
with open('data/wordnet.is_a.txt', 'r') as f:
    data = f.read()

# Each non-empty line is "<parent id> <child id>".
for line in data.split('\n'):
    if len(line) == 0:
        continue
    parent, child = line.split()
    if parent in synset_human:
        # Edges leaving an ImageNet class are dropped; the child of such an
        # edge is asserted not to be an ImageNet class itself.
        assert child not in synset_human, "child is in imagenet"
        continue
    # Record the edge in both graphs.
    G_hum.add_edge(synset_human_complete[parent], synset_human_complete[child])
    G_syn.add_edge(parent, child)

# Show the resulting node ids next to the ImageNet synset ids.
print("Nodes:", sorted(G_syn.nodes()))
print(sorted(synset_human.keys()))
#print("Edges:", G1.edges())

In [None]:
len(synset_human)

In [None]:
# get duplicated human labels
from collections import Counter
duplicates = [k for k,v in Counter(synset_human.values()).items() if v>1]
duplicates


In [None]:
for key in synset_human.keys():
    assert key in G_syn.nodes(), "key is in imagenet but not in graph"  

In [None]:
# #remove the leaf nodes that are not in imagenet
# for node in list(G1.nodes):
#     if node not in human_synset and G1.out_degree(node) == 0:
#         if node == "volleyball player":
#             print("removing", node, '\n')
#         G1.remove_node(node)

# for node in list(G2.nodes):
#     if node not in human_synset and G2.out_degree(node) == 0:
#         G2.remove_node(node)


In [None]:
#recursively remove the leaf nodes that are not in imagenet
def remove_leaf_nodes(G):
    """Repeatedly strip leaves of ``G`` whose name is not an ImageNet label.

    A node is removed when it has no outgoing edges and is not one of the values
    of the notebook-level ``synset_human`` map.  Removing a leaf can expose its
    parent as a new leaf, so passes are repeated until a pass removes nothing.

    Args:
        G: directed graph; it is modified in place.

    Returns:
        The same graph object ``G`` after pruning.
    """
    # Build the label set once: testing membership against dict.values() is a
    # linear scan per node.  Assumes the labels are hashable (they are used as
    # node names elsewhere in this notebook).
    imagenet_labels = set(synset_human.values())

    while True:
        len_before = len(G.nodes())
        # Iterate over a snapshot because nodes are removed during the pass.
        for node in list(G.nodes()):
            if G.out_degree(node) == 0 and node not in imagenet_labels:
                G.remove_node(node)
        len_after = len(G.nodes())
        print("len before", len_before, "len after", len_after)

        # A pass that removed nothing means only ImageNet leaves remain.
        if len_before == len_after:
            return G

#recursively remove the leaf nodes that are not in imagenet
def remove_leaf_node_syn(G):
    """Prune non-ImageNet leaves from the synset-keyed graph ``G`` in place.

    Nodes whose id is a key of the notebook-level ``synset_human`` map are never
    removed.  Every other node is removed as soon as it has no outgoing edges;
    passes repeat until the node count stops changing.  Returns ``G``.
    """
    while True:
        size_before = G.number_of_nodes()

        # Split the current nodes into protected ImageNet ids and removal candidates.
        protected = []
        candidates = []
        for node in G.nodes():
            if node in synset_human:
                protected.append(node)
            else:
                candidates.append(node)
        print("total nodes in imagenet", len(protected))
        print("nodes in imagenet", len(protected))

        # Out-degree is checked live, so a node whose children were removed
        # earlier in this same pass is removed too.
        for node in candidates:
            if G.out_degree(node) == 0:
                G.remove_node(node)

        size_after = G.number_of_nodes()
        print("len before", size_before, "len after", size_after)

        if size_before == size_after:
            return G


In [None]:
# Prune the human-readable graph. remove_leaf_nodes modifies its argument and
# returns that same object, so G_hum_filtered and G_hum are one graph.
G_hum_filtered = remove_leaf_nodes(G_hum)

In [None]:
# Same for the synset-keyed graph: G_syn_filtered is G_syn, pruned in place.
G_syn_filtered = remove_leaf_node_syn(G_syn)

In [None]:
# All nodes without outgoing edges that are left after pruning.
leaf_nodes =[node for node in list(G_syn.nodes) if G_syn.out_degree(node) == 0]

In [None]:
# (leaves that are ImageNet synsets, number of leaves, number of ImageNet synsets)
len(set(leaf_nodes).intersection(set(synset_human.keys()))), len(leaf_nodes), len(synset_human.keys())

In [None]:
from collections import defaultdict

In [None]:
def closest_leaf_nodes(G, orig_node, parents = None, k=4, siblings_ordered=None):
    """Collect up to ``k`` leaf nodes of ``G`` that are closest to ``orig_node``.

    The search starts at the direct predecessors of ``orig_node`` and gathers
    every leaf (out-degree 0) reachable from them, ranked by shortest-path
    length to ``orig_node`` in the undirected view of ``G``.  While fewer than
    ``k`` leaves have been recorded, the search widens to the predecessors of
    the current ancestors.  It stops when ``k`` leaves are recorded or when no
    ancestors remain.

    Args:
        G: directed graph with edges pointing from parent to child.
        orig_node: node whose neighbouring leaves are wanted.
        parents: ancestors to search from; ``None`` means the direct
            predecessors of ``orig_node``.
        k: number of leaves to keep.
        siblings_ordered: mapping ``node -> list of leaves`` that is updated in
            place under the key ``orig_node``.  ``None`` creates a fresh mapping
            for this call (a shared default instance would leak results from
            one call into the next).

    Returns:
        The list recorded for ``orig_node`` in ``siblings_ordered``; an empty
        list if nothing was recorded and there are no ancestors to search.
    """
    if siblings_ordered is None:
        siblings_ordered = defaultdict(list)
    if parents is None:
        parents = list(G.predecessors(orig_node))
    if len(parents) == 0:
        # Either orig_node has no predecessors or the widening search ran out of
        # ancestors; recursing further with an empty list would never terminate.
        return siblings_ordered.get(orig_node, [])

    found = siblings_ordered.setdefault(orig_node, [])
    # One undirected view is enough for all distance queries of this level.
    undirected = G.to_undirected(as_view=True)

    # Leaves under the current ancestors, in first-seen order, with their distance.
    distances = {}
    for parent in parents:
        for leaf in nx.dfs_preorder_nodes(G, parent):
            if leaf == orig_node or leaf in distances or leaf in found:
                continue
            if G.out_degree(leaf) == 0:
                distances[leaf] = nx.shortest_path_length(undirected, orig_node, leaf)

    # sorted() is stable, so equally distant leaves keep their first-seen order.
    found.extend(sorted(distances, key=distances.get)[:k])

    if len(found) >= k:
        siblings_ordered[orig_node] = found[:k]
        return siblings_ordered[orig_node]

    # Not enough leaves yet: widen the search one level up (ancestors de-duplicated,
    # order preserved, so the first-seen order of leaves is unchanged).
    parents_of_parents = list(dict.fromkeys(
        grandparent for parent in parents for grandparent in G.predecessors(parent)
    ))
    return closest_leaf_nodes(G, orig_node, parents_of_parents, k, siblings_ordered)
            

In [None]:
def closest_leaf_nodes(G, orig_node, parents = None, k=4, siblings_ordered=None):
    """Collect up to ``k`` leaf nodes of ``G`` that are closest to ``orig_node``.

    The search starts at the direct predecessors of ``orig_node`` and gathers
    every leaf (out-degree 0) reachable from them, ranked by shortest-path
    length to ``orig_node`` in the undirected view of ``G``.  While fewer than
    ``k`` leaves have been recorded, the search widens to the predecessors of
    the current ancestors.  It stops when ``k`` leaves are recorded or when no
    ancestors remain.

    Args:
        G: directed graph with edges pointing from parent to child.
        orig_node: node whose neighbouring leaves are wanted.
        parents: ancestors to search from; ``None`` means the direct
            predecessors of ``orig_node``.
        k: number of leaves to keep.
        siblings_ordered: mapping ``node -> list of leaves`` that is updated in
            place under the key ``orig_node``.  ``None`` creates a fresh mapping
            for this call (a shared default instance would leak results from
            one call into the next).

    Returns:
        The list recorded for ``orig_node`` in ``siblings_ordered``; an empty
        list if nothing was recorded and there are no ancestors to search.
    """
    if siblings_ordered is None:
        siblings_ordered = defaultdict(list)
    if parents is None:
        parents = list(G.predecessors(orig_node))
    if len(parents) == 0:
        # Either orig_node has no predecessors or the widening search ran out of
        # ancestors; recursing further with an empty list would never terminate.
        return siblings_ordered.get(orig_node, [])

    found = siblings_ordered.setdefault(orig_node, [])
    # One undirected view is enough for all distance queries of this level.
    undirected = G.to_undirected(as_view=True)

    # Leaves under the current ancestors, in first-seen order, with their distance.
    distances = {}
    for parent in parents:
        for leaf in nx.dfs_preorder_nodes(G, parent):
            if leaf == orig_node or leaf in distances or leaf in found:
                continue
            if G.out_degree(leaf) == 0:
                distances[leaf] = nx.shortest_path_length(undirected, orig_node, leaf)

    # sorted() is stable, so equally distant leaves keep their first-seen order.
    found.extend(sorted(distances, key=distances.get)[:k])

    if len(found) >= k:
        siblings_ordered[orig_node] = found[:k]
        return siblings_ordered[orig_node]

    # Not enough leaves yet: widen the search one level up (ancestors de-duplicated,
    # order preserved, so the first-seen order of leaves is unchanged).
    parents_of_parents = list(dict.fromkeys(
        grandparent for parent in parents for grandparent in G.predecessors(parent)
    ))
    return closest_leaf_nodes(G, orig_node, parents_of_parents, k, siblings_ordered)
            

In [None]:
# Try the search on a single synset, collecting results into a fresh mapping.
siblings_ordered = defaultdict(list)
closest_leaf_nodes(G_syn_filtered, 'n01484850', k=4, siblings_ordered=siblings_ordered)

In [None]:
print(siblings_ordered)

In [None]:
# Print each selected synset id followed by its human-readable name.
for i in siblings_ordered['n01484850']:
    print(i)
    print(synset_human_complete[i])
    print('\n')

In [None]:
# Run the search for every ImageNet synset, printing each class and its selected leaves.
siblings_ordered = defaultdict(list)
for node in synset_human:
    closest_leaf_nodes(G_syn_filtered, node, k=4, siblings_ordered=siblings_ordered)
    print("Node:", synset_human_complete[node], node)
    for sibling in siblings_ordered[node]:
        print("Sibling:", synset_human_complete[sibling])


In [None]:
# save the siblings_ordered dict


In [None]:
# NOTE(review): `x` is never assigned at notebook level in the visible cells (it
# only appears as a local inside closest_leaf_nodes), so this cell presumably
# relied on leftover kernel state; confirm what it is meant to iterate over.
for node in x:
    print(synset_human_complete[node])

In [None]:
# Same mapping as siblings_ordered, with every synset id replaced by its human-readable name.
siblings_hum = {
    synset_human_complete[key]: [synset_human_complete[i] for i in value]
    for key, value in siblings_ordered.items()
}

In [None]:
# (stray expression `su` disabled: the name is undefined and raised NameError on Run All)

In [None]:
#invert the dict index_synset to synset_index
synset_index = {v: k for k, v in index_synset.items()}

In [None]:
siblings_idx = {}
for key, value in siblings_ordered.items():
    siblings_idx[synset_index[key]] = [synset_index[i] for i in value]

In [None]:
# Write the neighbour map three times: keyed by synset id, by human-readable
# name, and by class index.
yaml_targets = (
    ('data/image_idx_to_tgt_class_closest_5.yaml', siblings_ordered),
    ('data/image_idx_to_tgt_class_closest_5_human.yaml', siblings_hum),
    ('data/image_idx_to_tgt_class_closest_5_idx.yaml', siblings_idx),
)
for yaml_path, mapping in yaml_targets:
    with open(yaml_path, 'w') as file:
        # dict(...) turns a defaultdict into a plain dict before dumping.
        documents = yaml.dump(dict(mapping), file)

### Get data from the W&B artifact

In [None]:
# `pd` is used below but is not imported in any visible cell of this notebook.
import pandas as pd
import wandb

run = wandb.init()
artifact = run.use_artifact('kifarid/cdiff/run-nfw1b6uy-dvce_video:v4', type='run_table')
# Fetch the logged table once and reuse it for both the rows and the column names.
dvce_video_table = artifact.get("dvce_video")
df = pd.DataFrame(data=dvce_video_table.data, columns=dvce_video_table.columns)
memory_usage = df.memory_usage(deep=True).sum()
print(f"Memory usage of dataframe is {memory_usage/1e6} MB")
df.head()

### Creating the tgt_classes


In [None]:



# Candidate target classes per class label, loaded from YAML.
with open('data/synset_closest_idx.yaml', 'r') as file:
    synset_closest_idx = yaml.safe_load(file)


data_path = '/misc/scratchSSD2/datasets/ILSVRC2012/val'
out_size = 256
transform_list = [
    transforms.Resize((out_size,out_size)),
    transforms.ToTensor()
]
transform = transforms.Compose(transform_list)
dataset = datasets.ImageFolder(data_path,  transform=transform)

# A dedicated, seeded generator makes the sampled targets reproducible across runs.
rng = random.Random(0)

# Only the labels are needed, so read them from dataset.targets instead of
# indexing the dataset, which would decode and transform every image.
idx_image_to_tgt_class = {}
for i, label in enumerate(dataset.targets):
    idx_image_to_tgt_class[i] = rng.choice(synset_closest_idx[label])
    if i%100==0:
        print(f"current image index: {i}")

In [None]:
#convert default dict to dict


In [None]:
# Persist the image-index -> target-class mapping as a plain dict.
with open('data/image_idx_to_tgt_class.yaml', 'w') as out_file:
    documents = yaml.dump(dict(idx_image_to_tgt_class), out_file)

### Creating the new ImageNet wrapper

In [3]:
# Load the mapping back from disk with the safe YAML loader.
with open('data/image_idx_to_tgt_class_closest_5.yaml', 'r') as in_file:
    image_idx_to_tgt_class_closest_5 = yaml.safe_load(in_file)

In [8]:
from data.imagenet_classnames import name_map, folder_label_map

In [9]:
class ImageNet(datasets.ImageFolder):
    """``ImageFolder`` over an ImageNet split that can also return a target class per sample.

    For the ``"val"`` split the samples are re-ordered so that consecutive
    entries cycle through the classes (the k-th image of every class, for
    k = 0, 1, ...), assuming the folder holds 50 images per class in class
    order.  ``start_sample`` / ``end_sample`` then select a range of those
    per-class image positions.

    Args:
        root: dataset root, or the split folder itself if it ends in
            ``"val"`` / ``"train"``.
        split: ``"val"`` or ``"train"``; only ``"val"`` is implemented.
        transform, target_transform: forwarded to ``ImageFolder``.
        class_idcs: optional iterable of class indices to keep; kept classes
            are re-labelled ``0..len(class_idcs)-1`` in sorted order.
        start_sample, end_sample: half-open range of per-class image positions
            (0..50) to keep.
        return_tgt_cls: if true, ``__getitem__`` appends the target class of
            the sample to the returned tuple.
        idx_to_tgt_cls_path: YAML file holding a list (or a dict keyed
            ``0..n-1``) with one target class per sample of the unfiltered,
            un-reordered folder.  May be ``None`` when ``return_tgt_cls`` is
            false.
        **kwargs: accepted and ignored, for consistency with other datasets.

    Raises:
        ValueError: if ``return_tgt_cls`` is true but no target classes were loaded.
        NotImplementedError: for ``split="train"``.
    """
    classes = [name_map[i] for i in range(1000)]
    name_map = name_map

    def __init__(
            self, 
            root:str, 
            split:str="val", 
            transform=None, 
            target_transform=None, 
            class_idcs=None, 
            start_sample: int = 0, 
            end_sample: int = 50000//1000,
            return_tgt_cls: bool = False,
            idx_to_tgt_cls_path = None, 
            **kwargs
    ):
        _ = kwargs  # Just for consistency with other datasets.
        per_class = 50000 // 1000  # validation images per class
        assert split in ["train", "val"]
        assert start_sample < end_sample and start_sample >= 0 and end_sample <= per_class
        path = root if root.endswith(("val", "train")) else os.path.join(root, split)
        super().__init__(path, transform=transform, target_transform=target_transform)

        # The target-class file is optional; opening a None path would raise.
        idx_to_tgt_cls = None
        if idx_to_tgt_cls_path is not None:
            with open(idx_to_tgt_cls_path, 'r') as file:
                idx_to_tgt_cls = yaml.safe_load(file)
            if isinstance(idx_to_tgt_cls, dict):
                # Dict keyed 0..n-1 -> list in index order.
                idx_to_tgt_cls = [idx_to_tgt_cls[i] for i in range(len(idx_to_tgt_cls))]
        if return_tgt_cls and idx_to_tgt_cls is None:
            raise ValueError("return_tgt_cls=True requires idx_to_tgt_cls_path")
        self.idx_to_tgt_cls = idx_to_tgt_cls

        self.return_tgt_cls = return_tgt_cls

        # Number of classes represented in self.samples; used as the stride of
        # the start/end slice below (1000 when no class filter is applied).
        n_classes = 1000

        if class_idcs is not None:
            class_idcs = list(sorted(class_idcs))
            tgt_to_tgt_map = {c: i for i, c in enumerate(class_idcs)}
            self.classes = [self.classes[c] for c in class_idcs]
            samples = []
            kept_tgt_cls = []
            for i, (p, t) in enumerate(self.samples):
                if t in tgt_to_tgt_map:
                    samples.append((p, tgt_to_tgt_map[t]))
                    if self.idx_to_tgt_cls is not None:
                        kept_tgt_cls.append(self.idx_to_tgt_cls[i])

            # Keep samples and target classes aligned: both are filtered together.
            self.samples = samples
            if self.idx_to_tgt_cls is not None:
                self.idx_to_tgt_cls = kept_tgt_cls
            self.class_to_idx = {k: tgt_to_tgt_map[v] for k, v in self.class_to_idx.items() if v in tgt_to_tgt_map}
            n_classes = len(class_idcs)

        if "val" == split: # reorder
            new_samples = []
            new_tgt_cls = []
            for idx in range(per_class):
                # Every per_class-th entry starting at idx: the idx-th image of each class.
                new_samples.extend(self.samples[idx::per_class])
                if self.idx_to_tgt_cls is not None:
                    new_tgt_cls.extend(self.idx_to_tgt_cls[idx::per_class])
            first, last = start_sample * n_classes, end_sample * n_classes
            self.samples = new_samples[first:last]
            if self.idx_to_tgt_cls is not None:
                self.idx_to_tgt_cls = new_tgt_cls[first:last]

        else:
            raise NotImplementedError

        self.class_labels = {i: folder_label_map[folder] for i, folder in enumerate(self.classes)}
        # Integer class index per sample.  Building the array from the labels
        # alone avoids the string array that a mixed (path, label) array yields.
        self.targets = np.array([t for _, t in self.samples])
    
    def __getitem__(self, index):
        """Return the ``ImageFolder`` sample, plus its target class if ``return_tgt_cls`` is set."""
        sample = super().__getitem__(index)
        if self.return_tgt_cls:
            return (*sample, self.idx_to_tgt_cls[index])
        return sample

In [None]:
# Display the mapping loaded from YAML.
image_idx_to_tgt_class_closest_5

In [10]:
# Convert the dict to a list by looking up the integer keys 0..len-1.
# NOTE(review): the recorded output of this cell is `KeyError: 0`. The file of
# this name is written elsewhere in this notebook from `siblings_ordered`, which
# is keyed by synset ids rather than integers, so the integer lookup fails.
# Presumably an integer-keyed file was intended; confirm which one.
image_idx_to_tgt_class_closest_5_list = [ image_idx_to_tgt_class_closest_5[i] for i in range(len(image_idx_to_tgt_class_closest_5))]

KeyError: 0

In [15]:
# Build the wrapper with the default split and ask for the target class with each sample.
ds = ImageNet('/misc/scratchSSD2/datasets/ILSVRC2012', idx_to_tgt_cls_path='data/image_idx_to_tgt.yaml', return_tgt_cls = True)


In [17]:
import torch 


In [18]:
 # Loader over ds: batches of two, no shuffling, one worker process.
 data_loader = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False, num_workers=1)

<function torch.utils.data._utils.collate.default_collate(batch)>

In [None]:
# Spot checks on the wrapper: class entries, class label and individual samples.
ds.classes[21]

In [None]:
ds[21][0]

In [None]:
ds.class_labels[21]

In [None]:
ds[3]

In [None]:
ds.classes[22]

In [None]:
# Resize every image to out_size x out_size and convert it to a tensor.
out_size = 256
transform_list = [
    transforms.Resize((out_size, out_size)),
    transforms.ToTensor()
]
transform = transforms.Compose(transform_list)

In [None]:
# NOTE(review): ImageNet.__init__ as defined in this notebook has no
# `idx_to_tgt_cls` parameter, so that keyword lands in **kwargs and is ignored,
# and no idx_to_tgt_cls_path is passed here. Confirm the intended argument.
ds = ImageNet('/misc/scratchSSD2/datasets/ILSVRC2012', split="val", return_tgt_cls = True, idx_to_tgt_cls=image_idx_to_tgt_class_closest_5_list, transform=transform)

In [None]:
ds[100]

In [None]:
print('test')

## New cone projection approach 

In [None]:
# get image tench_eaten
from PIL import Image
import requests
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
import torch

# Open the example image (path is relative to the working directory) and show
# it as the cell output.
path = 'tench_eaten.png'
img = Image.open(path)

img

In [None]:
def cone_project_chuncked(grad_temp_1, grad_temp_2, deg, chunk_size = 2, verbose = True):
    """Chunk-wise replacement of ``grad_temp_2`` where it is more than ``deg`` away from ``grad_temp_1``.

    Both inputs are flattened square RGB gradients of shape ``(B, 3*S*S)``.
    Each is cut into non-overlapping ``chunk_size x chunk_size`` spatial
    patches (all three channels together), giving one vector of length
    ``3 * chunk_size**2`` per patch.  For every patch the angle between the
    two vectors is computed.  Where it exceeds ``deg`` degrees, the output
    patch is the component of ``grad_temp_1`` orthogonal to ``grad_temp_2``;
    elsewhere it is the ``grad_temp_2`` patch unchanged.

    grad_temp_1: gradient of the loss w.r.t. the robust/classifier free
    grad_temp_2: gradient of the loss w.r.t. the non-robust
    projecting the robust/CF onto the non-robust

    Args:
        grad_temp_1: tensor of shape ``(B, 3*S*S)``.
        grad_temp_2: tensor with the same shape as ``grad_temp_1``.
        deg: cone half-angle in degrees.
        chunk_size: side length of the square patches; must divide ``S``.
        verbose: print the intermediate shapes and tensors (on by default).

    Returns:
        ``(grad_temp, mask)``: ``grad_temp`` has shape ``(B, 3, S, S)`` and
        ``mask`` is a boolean tensor of shape ``(B, S//chunk_size,
        S//chunk_size)`` that is true where the angle exceeded ``deg``.

    Raises:
        ValueError: if the last dimension is not ``3*S*S`` for an integer ``S``
            or ``S`` is not a multiple of ``chunk_size``.
    """
    import math  # local import: only needed for the exact integer square root

    batch = grad_temp_1.shape[0]
    flat_dim = grad_temp_1.shape[-1]
    side = math.isqrt(flat_dim // 3)
    if 3 * side * side != flat_dim:
        raise ValueError(f"last dimension {flat_dim} is not 3 * S * S for an integer S")
    if side % chunk_size != 0:
        raise ValueError(f"image side {side} is not a multiple of chunk_size {chunk_size}")

    orig_shp = (batch, 3, side, side)
    n_chunks = side // chunk_size
    if verbose:
        print(orig_shp)

    def _to_chunks(flat):
        # (B, 3*S*S) -> (B, S/c, S/c, 3*c*c): one vector per spatial patch.
        return flat.reshape(*orig_shp) \
            .unfold(2, chunk_size, chunk_size) \
            .unfold(3, chunk_size, chunk_size) \
            .permute(0, 1, 4, 5, 2, 3) \
            .reshape(batch, -1, n_chunks, n_chunks) \
            .permute(0, 2, 3, 1)

    grad_temp_1_chuncked = _to_chunks(grad_temp_1)
    grad_temp_2_chuncked = _to_chunks(grad_temp_2)
    if verbose:
        print(grad_temp_1_chuncked.shape, grad_temp_2_chuncked.shape)

    # Angle between the two gradients, per patch.
    grad_2_len = grad_temp_2_chuncked.norm(p=2, dim=-1)
    angles_before_chuncked = torch.acos(
        (grad_temp_1_chuncked * grad_temp_2_chuncked).sum(-1)
        / (grad_temp_1_chuncked.norm(p=2, dim=-1) * grad_2_len)
    )

    # Remove from grad 1 its component along grad 2 (per patch).
    grad_temp_2_chuncked_norm = grad_temp_2_chuncked / grad_2_len.unsqueeze(-1)
    along = (grad_temp_1_chuncked * grad_temp_2_chuncked_norm).sum(-1) \
        / (grad_temp_2_chuncked_norm.norm(p=2, dim=-1) ** 2)
    grad_temp_1_chuncked = grad_temp_1_chuncked - along.unsqueeze(-1) * grad_temp_2_chuncked_norm

    grad_temp_1_chuncked_norm = grad_temp_1_chuncked / grad_temp_1_chuncked.norm(p=2, dim=-1).unsqueeze(-1)
    radians = torch.tensor([deg], device=grad_temp_1_chuncked.device).deg2rad()
    # Computed but currently unused: the assignment below substitutes the
    # orthogonal component instead (the cone-projection variant is left commented).
    cone_projection = grad_2_len.unsqueeze(-1) * grad_temp_1_chuncked_norm * torch.tan(radians) + grad_temp_2_chuncked

    outside_cone = angles_before_chuncked > radians
    if verbose:
        print(grad_temp_2_chuncked)

    grad_temp_chuncked = grad_temp_2_chuncked.clone()
    if verbose:
        print(outside_cone, grad_temp_1_chuncked.shape)
    grad_temp_chuncked[outside_cone] = grad_temp_1_chuncked[outside_cone] #cone_projection[outside_cone]
    if verbose:
        print(grad_temp_chuncked.shape)

    # Inverse of _to_chunks: (B, S/c, S/c, 3*c*c) -> (B, 3, S, S).
    grad_temp = grad_temp_chuncked.permute(0, 3, 1, 2) \
        .reshape(batch, 3, chunk_size, chunk_size, n_chunks, n_chunks) \
        .permute(0, 1, 4, 2, 5, 3) \
        .reshape(*orig_shp)

    if verbose:
        print(angles_before_chuncked.shape)

    return grad_temp, outside_cone

In [None]:
# Smoke test: two random 4x4 RGB tensors, flattened to (1, 48), projected with a 45 degree cone.
input_tensor_1, input_tensor_2 = (torch.rand(1, 3, 4, 4).float() for _ in range(2))
cone_project_chuncked(input_tensor_1.reshape(1, -1), input_tensor_2.reshape(1, -1), 45., chunk_size = 2)