
Paper results #8

Closed
kltrock opened this issue Jul 16, 2022 · 2 comments


kltrock commented Jul 16, 2022

Very interesting work!

Hi, I am very interested in your work. Could you explain how to replicate the results in Table 3 of the main paper? More specifically, I would like to study the effect of the two modules by switching them on and off, as in Table 3.

dahyun-kang (Owner) commented

Hello kltrock

Thanks for your interest!
At the moment I cannot provide the full implementation of the ablation studies, but here are some code snippets to help you reproduce them. The two modules are essentially plug-and-play, so you can simply take them out of the code to see their effects. All you need to do is update some code in renet/models/renet.py following the steps below, and it should work. All the command lines can stay the same as provided.

  1. Remove the learnable SCR & CCA modules.

    renet/models/renet.py

    Lines 25 to 31 in 43a4e5a

    self.scr_module = self._make_scr_layer(planes=[640, 64, 64, 64, 640])
    self.cca_module = CCA(kernel_sizes=[3, 3], planes=[16, 1])
    self.cca_1x1 = nn.Sequential(
        nn.Conv2d(self.encoder_dim, 64, kernel_size=1, bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU()
    )

    The above self.scr_module, self.cca_module, and self.cca_1x1 variables hold the learnable parameters of SCR and CCA, so you can comment them out.
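As an illustration of the plug-and-play point (the wrapper below is my own sketch, not code from the repo): instead of deleting lines, an ablation switch can route features around a module. Note this only works where the module preserves the feature shape, as SCR does with its 640-to-640 planes; it does not apply to cca_1x1, which reduces the channel dimension.

```python
import torch
import torch.nn as nn

class AblatableModule(nn.Module):
    """Wraps a feature transform with an on/off switch for ablations.

    With enabled=False the input passes through unchanged, which mimics
    commenting the module out -- valid only when the wrapped module keeps
    the input shape (as SCR does), not for channel-reducing layers.
    """
    def __init__(self, module, enabled=True):
        super().__init__()
        self.module = module if enabled else nn.Identity()

    def forward(self, x):
        return self.module(x)

# usage: a stand-in 640-channel transform; when disabled, features pass through
x = torch.randn(2, 640, 5, 5)
ablated = AblatableModule(nn.Conv2d(640, 640, 3, padding=1), enabled=False)
assert torch.equal(ablated(x), x)
```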

  2. Remove the attention process in CCA.

    renet/models/renet.py

    Lines 72 to 124 in 43a4e5a

    def cca(self, spt, qry):
        spt = spt.squeeze(0)

        # shifting channel activations by the channel mean
        spt = self.normalize_feature(spt)
        qry = self.normalize_feature(qry)

        # (S * C * Hs * Ws, Q * C * Hq * Wq) -> Q * S * Hs * Ws * Hq * Wq
        corr4d = self.get_4d_correlation_map(spt, qry)
        num_qry, way, H_s, W_s, H_q, W_q = corr4d.size()

        # corr4d refinement
        corr4d = self.cca_module(corr4d.view(-1, 1, H_s, W_s, H_q, W_q))
        corr4d_s = corr4d.view(num_qry, way, H_s * W_s, H_q, W_q)
        corr4d_q = corr4d.view(num_qry, way, H_s, W_s, H_q * W_q)

        # normalizing the entities for each side to be zero-mean and unit-variance to stabilize training
        corr4d_s = self.gaussian_normalize(corr4d_s, dim=2)
        corr4d_q = self.gaussian_normalize(corr4d_q, dim=4)

        # applying softmax for each side
        corr4d_s = F.softmax(corr4d_s / self.args.temperature_attn, dim=2)
        corr4d_s = corr4d_s.view(num_qry, way, H_s, W_s, H_q, W_q)
        corr4d_q = F.softmax(corr4d_q / self.args.temperature_attn, dim=4)
        corr4d_q = corr4d_q.view(num_qry, way, H_s, W_s, H_q, W_q)

        # summing up matching scores
        attn_s = corr4d_s.sum(dim=[4, 5])
        attn_q = corr4d_q.sum(dim=[2, 3])

        # applying attention
        spt_attended = attn_s.unsqueeze(2) * spt.unsqueeze(0)
        qry_attended = attn_q.unsqueeze(2) * qry.unsqueeze(1)

        # averaging embeddings for k > 1 shots
        if self.args.shot > 1:
            spt_attended = spt_attended.view(num_qry, self.args.shot, self.args.way, *spt_attended.shape[2:])
            qry_attended = qry_attended.view(num_qry, self.args.shot, self.args.way, *qry_attended.shape[2:])
            spt_attended = spt_attended.mean(dim=1)
            qry_attended = qry_attended.mean(dim=1)

        # In the main paper, we present averaging in Eq.(4) and summation in Eq.(5).
        # In the implementation the order is reversed, but the two eventually amount to the same thing :)
        spt_attended_pooled = spt_attended.mean(dim=[-1, -2])
        qry_attended_pooled = qry_attended.mean(dim=[-1, -2])
        qry_pooled = qry.mean(dim=[-1, -2])

        similarity_matrix = F.cosine_similarity(spt_attended_pooled, qry_attended_pooled, dim=-1)

        if self.training:
            return similarity_matrix / self.args.temperature, self.fc(qry_pooled)
        else:
            return similarity_matrix / self.args.temperature

    The code snippet above is precisely the CCA process. Replace the attention process starting from L80 with global average pooling for both the support and the query.
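A minimal standalone sketch of that ablated variant (the function name cca_ablated and the default temperature are my own; the shapes follow the 1-shot case of the snippet above): both support and query embeddings are global-average-pooled over their spatial dimensions, and cosine similarity is taken between every query and every class prototype, with no attention maps involved.

```python
import torch
import torch.nn.functional as F

def cca_ablated(spt, qry, temperature=1.0):
    """CCA forward pass with the attention process replaced by GAP.

    spt: support features, shape (way, C, Hs, Ws) -- 1-shot case
    qry: query features,   shape (num_qry, C, Hq, Wq)
    Returns a (num_qry, way) similarity matrix.
    """
    # global average pooling over spatial dims replaces the attention maps
    spt_pooled = spt.mean(dim=[-1, -2])   # (way, C)
    qry_pooled = qry.mean(dim=[-1, -2])   # (num_qry, C)
    # pairwise cosine similarity: every query against every class prototype
    similarity_matrix = F.cosine_similarity(
        spt_pooled.unsqueeze(0),          # (1, way, C)
        qry_pooled.unsqueeze(1),          # (num_qry, 1, C)
        dim=-1,
    )
    return similarity_matrix / temperature

# usage: 5-way episode, 3 queries, 640-channel 5x5 feature maps
spt = torch.randn(5, 640, 5, 5)
qry = torch.randn(3, 640, 5, 5)
sim = cca_ablated(spt, qry)
print(sim.shape)  # torch.Size([3, 5])
```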

I hope this helps. Have a great day! 😃


kltrock commented Aug 4, 2022

Thank you very much. I have reproduced the results.
