In [47]:
from fastai.basics import *
from fastai.vision.all import *

from torchvision.models import resnet18, ResNet18_Weights
from tqdm import tqdm
from collections import defaultdict
import cmasher as cmr

ROOT = Path("..").resolve()
seed = 256

In [67]:
import matplotlib.font_manager as fm
from matplotlib.colors import LinearSegmentedColormap

fm.fontManager.addfont("/Users/john/Library/Fonts/Nunito-Italic.otf")
fm.fontManager.addfont("/Users/john/Library/Fonts/Nunito-Regular.otf")
fm.fontManager.addfont("/Users/john/Library/Fonts/Nunito-Bold.otf")
fm.fontManager.addfont("/Users/john/Library/Fonts/Nunito-ExtraBold.otf")

plt.rcParams['font.family'] = 'Nunito'
plt.rcParams['font.weight'] = "bold"
plt.rcParams['mathtext.fontset'] = 'custom'
plt.rcParams['mathtext.sf'] = 'Nunito'
plt.rcParams['mathtext.rm'] = 'Nunito'
plt.rcParams['mathtext.it'] = 'Nunito'
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=["#003f5c","#2f4b7c","#665191","#a05195","#d45087","#f95d6a","#ff7c43","#ffa600",])

# Top-k resnet18

In [3]:
K = 4

def RMSE(p, y): return torch.sqrt(MSELossFlat()(p, y))
    
class ResNetTopK(nn.Module):
    """Resnet18-like model with a single projection head at end, and a top-k 
    sparsity constraint in penultimate layer to encourage interpretability.
    """
    def __init__(self, k=32, n_out=1000, pretrained=True, **kwargs):
        super(ResNetTopK, self).__init__()
        if pretrained:
            self.resnet = resnet18(weights=ResNet18_Weights.DEFAULT, **kwargs)
        else:
            self.resnet = resnet18(weights=None, **kwargs)
        self.k = k
        # change n_out features
        n_fc_in = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(n_fc_in, n_out)

    def forward(self, x):
        # regular convnet up to final layer
        features = nn.Sequential(*list(self.resnet.children())[:-1])(x)
        features = torch.flatten(features, 1)
        features = nn.functional.relu(features)
        
        # top-k constraint
        topk_values, topk_indices = torch.topk(features, k=self.k, dim=1)
        sparse_features = torch.zeros_like(features)
        sparse_features.scatter_(1, topk_indices, topk_values)
        
        # final fully connected layer
        x = self.resnet.fc(sparse_features)
        return x

In [4]:
df = pd.read_csv(ROOT / 'data/galaxies.csv', dtype={'objID': str})

df = df[
    (df.nii_6584_flux / df.nii_6584_flux_err  > 3)
    & (df.h_alpha_flux / df.h_alpha_flux_err > 3)
    & (df.oiii_5007_flux / df.oiii_5007_flux_err  > 3)
    & (df.h_beta_flux / df.h_beta_flux_err > 3)
    & (df.nii_6584_flux < 1e5)
    & (df.h_alpha_flux < 1e5)
    & (df.oiii_5007_flux < 1e5)
    & (df.h_beta_flux < 1e5)
].copy()

# df = df.sample(10000, random_state=256).copy()

n_galaxies = len(df)

# set a random state
rng = np.random.RandomState(seed)

In [5]:
# new targets
df["log_N2"] = np.log10(df.nii_6584_flux)
df["log_Ha"] = np.log10(df.h_alpha_flux)
df["log_O3"] = np.log10(df.oiii_5007_flux)
df["log_Hb"] = np.log10(df.h_beta_flux)

In [6]:
# fastai "data blocks" determine how data can be fed into a model
target = ['log_N2', 'log_Ha', 'log_O3', 'log_Hb']
dblock = DataBlock(
    blocks=(ImageBlock, RegressionBlock),
    get_x=ColReader('objID', pref=f'{ROOT}/data/images-sdss/', suff='.jpg'),
    get_y=ColReader(target),
    splitter=RandomSplitter(0.2, seed=seed),
    item_tfms=[Resize(160), CropPad(144)],
    batch_tfms=aug_transforms(do_flip=True, flip_vert=True, max_rotate=0, max_zoom=1.0, max_warp=0, p_lighting=0) + [Normalize()]
)

# "data loaders" actually load the data 
dls = ImageDataLoaders.from_dblock(dblock, df, bs=64)

In [7]:
cnn_model = ResNetTopK(k=K, n_out=len(target), pretrained=True).to("mps")

In [8]:
learn = Learner(
    dls,
    cnn_model,
    loss_func=RMSE,
    opt_func=ranger,
)

In [9]:
learn.fit_one_cycle(10, 0.1)

epoch,train_loss,valid_loss,time
0,0.329444,0.523534,26:13
1,0.29095,0.475089,25:47
2,0.276437,0.493881,25:27
3,0.270659,0.426233,24:43
4,0.274248,0.27713,24:52
5,0.266218,0.262335,24:56
6,0.255639,0.251537,24:57
7,0.25602,0.268484,24:58
8,0.250086,0.24831,25:03
9,0.241837,0.241771,25:09


In [10]:
model_path = f"{ROOT}/model/resnet18-topk_{K}-bpt_lines.pth"
torch.save(learn.model, model_path)

cnn_model = torch.load(model_path)

  cnn_model = torch.load(model_path)


# View activated features

In [11]:
def get_all_sparse_activations(loader, model):
    activations = []
    with torch.no_grad():
        layers = nn.Sequential(*list(model.resnet.children())[:-1], nn.Flatten())
        
        for xb, _ in tqdm(loader):    
            activations.append(layers(xb))
    return torch.concat(activations, 0).cpu().numpy()

In [12]:
activs = get_all_sparse_activations(dls.valid, cnn_model)
activs.shape

100%|█████████████████████████████████████████████████████████████████████| 782/782 [02:04<00:00,  6.26it/s]


(50041, 512)

In [14]:
activs_path = f"{ROOT}/results/resnet18-topk_{K}-bpt_lines/activations.npy"
np.save(activs_path, activs)

activations = np.load(activs_path)

In [15]:
(activations.max(0) > 0).sum()

24

In [16]:

# make a dictionary for every non-zero activated feature, where the
# key is the latent activation index, and the value is a list of tuples
# of (image index, activation strength)
feature_dict = defaultdict(list)

# Process each image's activations
for img_idx, img_activations in enumerate(activations):
    # Find non-zero activations
    non_zero = np.nonzero(img_activations)[0]

    # Add to dictionary
    for feature_idx in non_zero:
        activation_strength = img_activations[feature_idx]
        feature_dict[int(feature_idx)].append((int(img_idx), float(activation_strength)))

# Sort each list by activation strength in descending order
for feature_idx in feature_dict:
    feature_dict[feature_idx].sort(key=lambda x: x[1], reverse=True)


In [17]:
print([(k, len(feature_dict[k])) for k in feature_dict])

[(17, 31488), (138, 43517), (157, 47936), (322, 31500), (337, 25081), (399, 25497), (236, 29180), (242, 19432), (8, 5), (111, 2), (336, 60), (365, 2), (478, 3), (58, 1), (410, 1), (194, 1), (357, 1), (458, 4), (473, 4), (508, 1), (44, 5), (133, 1), (59, 2), (292, 1)]


In [42]:
valid_idx_to_objid = lambda idx: dls.valid.items.iloc[idx].objID

def plot_max_activating_galaxies(feature_dict, activation_index, top_n=5):
    galaxy_indices_and_activations = feature_dict[activation_index]

    top_n = min(top_n, len(galaxy_indices_and_activations))

    fig, axes = plt.subplots(1, top_n, figsize=(top_n * 1.5, 2), dpi=100, squeeze=0)
    axes = axes.reshape(-1)
    for ax, [galaxy_index, feature_activation] in zip(axes, galaxy_indices_and_activations):
        image = Image.open(f"{ROOT}/data/images-sdss/{valid_idx_to_objid(galaxy_index)}.jpg")
        ax.imshow(image, origin='lower')
        ax.set_title(f"{feature_activation:.4f}", fontsize=10)
        ax.axis("off")
    fig.suptitle(f"Activation {activation_index} ({len(galaxy_indices_and_activations)} galaxies)", fontsize=12)
    fig.subplots_adjust(left=0, right=1, top=0.8, wspace=0.02)

In [87]:
for k in tqdm(feature_dict):
    plot_max_activating_galaxies(feature_dict, k, top_n=10)
    plt.savefig(f"{ROOT}/results/resnet18-topk_{K}-bpt_lines/figures/{k}-examples.png")
    plt.close()

100%|███████████████████████████████████████████████████████████████████████| 24/24 [00:02<00:00, 10.82it/s]


# Show on BPT diagram normalized activation 

Only use common activations (i.e. ones that have non-zero values for 100+ cases)

In [101]:
plt.rcParams['font.weight'] = 700

for k in feature_dict:
    plt.figure(figsize=(4.8, 4), dpi=300)
    
    n2_ha = dls.valid.items.log_N2 - dls.valid.items.log_Ha
    o3_hb = dls.valid.items.log_O3 - dls.valid.items.log_Hb

    act_strength = activations[:, k] / activations[:, k].max()

    n2_ha = n2_ha.iloc[np.argsort(act_strength)]
    o3_hb = o3_hb.iloc[np.argsort(act_strength)]
    act_strength = act_strength[np.argsort(act_strength)]
    
    plt.scatter(
        n2_ha,
        o3_hb,
        c=act_strength, 
        edgecolors="none",
        cmap=cmr.ember,
        s=1,
        vmin=0,
        vmax=1,
    )
    cb = plt.colorbar()
    cb.set_label(label="Activation Strength", fontsize=12, fontfamily="Nunito", fontweight="bold")

    plt.title(f"Activation {k:>3} ($N$ = {len(feature_dict[k])})", fontsize=12, fontfamily="Nunito", fontweight="bold")
    
    plt.xlabel("log([NII]/H$\\alpha$)", fontsize=12, fontweight="bold")
    plt.ylabel("log([OIII]/H$\\beta$)", fontsize=12, fontweight="bold")
    # plt.legend(markerscale=10, loc="center right", framealpha=0, markerfirst=False, borderpad=0.05, handletextpad=0.05, title_fontsize=14)
    plt.grid(alpha=0.15)
    plt.xlim(-1.55, 0.55)
    plt.ylim(-1.05, 1.3)
    plt.savefig(f"{ROOT}/results/resnet18-topk_{K}-bpt_lines/figures/{k}-bpt_scatter.png")
    plt.close()